In [None]:
import os
import yaml
import json
import argparse
import threading
import numpy as np
from datetime import datetime
from typing import List
from langchain.prompts import PromptTemplate
from asn.env.environment import Environment
from asn.data.data import Data
from asn.agent.agent import Agent, LLMAgent
from asn.agent.action import Act
from asn.llm.llm import LLMManager
from asn.llm.prompt import PROMPT_EVALUATE_ACT, PROMPT_EVALUATE_POST, PROMPT_EMOTION_CLASSIFICATION
from asn.utils.time import *
from asn.utils.logger import get_logger, set_logger

import warnings
warnings.filterwarnings("ignore")

# Command-line parsing is kept for reference but disabled, so the notebook
# runs with the fixed paths given in the Namespace below.
# set_logger(log_folder="log_eval")
# args = argparse.ArgumentParser()
# args.add_argument("--config_path", "-c", type=str, default="config/example.yaml")
# args.add_argument("--data_path", "-d", type=str, default="dataset/example.json")
# args = args.parse_args()
args = argparse.Namespace(config_path="config/example.yaml", data_path="dataset/example.json")

# Load the YAML config; every argument value then overrides/extends a config key.
conf_path = args.config_path
with open(conf_path, "r") as f:
    conf = yaml.load(f, Loader=yaml.FullLoader)
conf.update(vars(args))

# Report the effective settings to the logger (info and debug) and to stdout.
settings = "Settings:\n" + "".join(f"{name}: {val}\n" for name, val in conf.items())
get_logger().info(settings)
get_logger().debug(settings)
print(settings)

# Random seed: use conf["seed"] when present, otherwise fall back to 0.
# The resolved `seed` (not conf["seed"]) must be used below, otherwise a
# config without a "seed" key raises KeyError despite the fallback.
seed = conf.get("seed", 0)
import random
random.seed(seed)
np.random.seed(seed)

# Point the LLM layer at the configured backend.
LLMManager.set_manager(conf["llm"])

# Restore the saved environment from
# <save_path>/model_<time_sim_end with spaces replaced by underscores>/model.json.
save_path = "{}/model_{}".format(conf["save_path"], conf["time_sim_end"].replace(" ", "_"))
data = Data.load_from_data(conf["data_path"])
with open(os.path.join(save_path, "model.json"), "r") as f:
    model_dict = json.load(f)
env = Environment.load_from_dict(model_dict["env"])

# Attach each user's Mastodon connection info, read from the dataset metadata.
for user in env.users:
    user.mastodon_info = data.get_meta(user.id, "mastodon_info")


In [None]:
# Sanity check: number of action-log entries in the restored environment.
print(len(env.log))

In [None]:
import threading
from mastodon import Mastodon
from asn.utils.mastodon import *
from asn.llm.prompt import Prompts
from multiavatar.multiavatar import multiavatar
from cairosvg import svg2png
import time


# Normalize the repost action type in env.log: "retweet", "repost" and "share"
# are all rewritten to "repost", so later cells only need to check one name.
for entry in env.log:
    act_info = entry["act"]
    if act_info["type"] in ("retweet", "repost", "share"):
        act_info["type"] = "repost"

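# Translate user profiles (disabled below; the code that follows assumes characteristics_translated is already set)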
# print("翻译用户profile")
# def translate_profile(user):
#     llm = LLMManager.get_llm()
#     characteristics = user.agent.profile.characteristics
#     characteristics = characteristics[characteristics.find("characteristics:")+len("characteristics:"):]
#     prompt = Prompts.post_translation.format(post_text=characteristics)
#     user.agent.profile.characteristics_translated = llm._call(prompt)
# threads = []
# for user in env.users:
#     t = threading.Thread(target=translate_profile, args=(user,))
#     threads.append(t)
#     t.start()
# for t in threads:
#     t.join()
# Show each user's translated profile text.
for profile in (u.agent.profile for u in env.users):
    print(profile.characteristics_translated)

# Log in as every user: cache the access token on the user and build exactly
# one Mastodon API client per user (the client is reused by later cells via
# `user.mastodon`; the token via `user.access_token`).
for user in env.users:
    user.access_token = user.mastodon_info["access_token"]
    user.mastodon = Mastodon(
        access_token=user.access_token,
        api_base_url=user.mastodon_info["api_base_url"],
    )

# Push each user's translated profile text into the `note` field of their
# Mastodon account.
for account_owner in env.users:
    translated_note = account_owner.agent.profile.characteristics_translated
    account_owner.mastodon.account_update_credentials(note=translated_note)

# Translate posts.
print("翻译帖子")
def translate_post(message):
    """Store an LLM translation of ``message.text`` in ``message.text_translated``.

    Messages whose ``text_translated`` is already set are left untouched. If
    the LLM call raises, the untranslated ``message.text`` is stored instead so
    the replay always has some text to post; the failure is logged rather than
    silently discarded.

    Args:
        message: a message from ``env.messages``.
    """
    # Check first so already-translated messages cost no prompt/LLM setup.
    if message.text_translated is not None:
        return
    llm = LLMManager.get_llm()
    prompt = Prompts.post_translation.format(post_text=message.text)
    try:
        message.text_translated = llm._call(prompt)
    except Exception as exc:
        get_logger().warning(
            f"Post translation failed for message {message.id}; using original text: {exc}"
        )
        message.text_translated = message.text
# Fan out one translation thread per message, then wait for all of them.
threads = []
for msg in env.messages:
    worker = threading.Thread(target=translate_post, args=(msg,))
    worker.start()
    threads.append(worker)
for worker in threads:
    worker.join()

# Save the model (now including translated posts) back to model.json.
# Serialize first, then write to a temporary file and atomically swap it in:
# opening model.json directly with "w" truncates it, so any failure inside
# save_to_dict or json.dump would otherwise destroy the only saved copy.
save_path = conf["save_path"] + f"/model_{conf['time_sim_end'].replace(' ', '_')}"
model_dict = {
    "env": env.save_to_dict(save_path)
}
model_path = os.path.join(save_path, "model.json")
tmp_model_path = model_path + ".tmp"
with open(tmp_model_path, "w") as f:
    json.dump(model_dict, f, indent=4)
os.replace(tmp_model_path, model_path)

# Replay window for the next cell. It starts at conf["time_init_begin"]
# rather than conf["time_sim_begin"].
# NOTE(review): presumably so posts from the initialisation phase are replayed too — confirm.
time_sim_begin, time_sim_end, time_intv = (
    conf["time_init_begin"],
    conf["time_sim_end"],
    conf["interval"],
)
time_step = time_sim_begin
# Number of logs already replayed; the next cell resumes from here after an error.
cnt_log = 0

In [None]:
# Replay the simulation log onto Mastodon, one time interval at a time.
# After an exception, re-running this cell continues from the last unprocessed log:
# cnt_log is incremented only after a log has been handled, and time_step still
# holds the interval in which the failure happened.
# NOTE(review): slicing env.log by cnt_log assumes env.log is ordered by
# timestamp (so processing order matches list order) — confirm.
logs = env.log[cnt_log:]
# Process every log; sleep 10 s after every 100 logs (read logs included),
# presumably to stay under the Mastodon API rate limit.
while time_step < time_sim_end:
    time_next = add_interval(time_step, time_intv)
    log_step = [log for log in logs if time_step <= log["timestamp"] < time_next]
    print(f"Time: {time_step}")
    for log in log_step:
        user = env.get_user_by_id(log["user"])
        message = env.get_message_by_id(log["message"])
        # Resolve to the message's origin (presumably the source post of a repost — confirm).
        message = env.get_message_by_id(message.origin_id())
        act = Act.load_from_dict(log["act"])
        if act.type != "read": # don't print read logs
            print(f"User: {user.id}, Act: {act.type}, Message: {message.text_translated}")
        if act.type == "post":
            # Pick a random time between time_step and time_next
            time_post = generate_random_between(time_step, time_next)
            status = post_status_with_time(user.access_token, message.text_translated, time_post)
            # Remember the Mastodon status so later likes/reposts can target its id.
            message.mastodon_info = {
                "id": status["id"],
                "content": status["content"]
            }
        elif act.type == "like":
            # NOTE(review): requires the origin message to have been posted earlier in this replay.
            user.mastodon.status_favourite(message.mastodon_info["id"])
        elif act.type == "repost":
            user.mastodon.status_reblog(message.mastodon_info["id"])
        cnt_log += 1
        if cnt_log % 100 == 0:
            time.sleep(10)
    time_step = time_next

In [None]:
# Per-message sets of user ids that read / liked / reposted it (keyed by the
# origin message id).
msgid2likes = {msg.id: set() for msg in env.messages}
msgid2retweets = {msg.id: set() for msg in env.messages}
msgid2reads = {msg.id: set() for msg in env.messages}
# Walk every log inside the simulation window [time_sim_begin, time_sim_end).
time_sim_begin = conf["time_sim_begin"]
time_sim_end = conf["time_sim_end"]
time_intv = conf["interval"]
time_step = time_sim_begin
while time_step < time_sim_end:
    time_next = add_interval(time_step, time_intv)
    log_step = [log for log in env.log if time_step <= log["timestamp"] < time_next]
    print(f"Time: {time_step}")
    for log in log_step:
        user = env.get_user_by_id(log["user"])
        message = env.get_message_by_id(log["message"])
        message = env.get_message_by_id(message.origin_id())
        act = Act.load_from_dict(log["act"])
        if act.type != "read":  # only print non-read logs
            print(f"User: {user.id}, Act: {act.type}, Message: {message.text_translated}")
        if act.type == "read":
            msgid2reads[message.id].add(user.id)
        elif act.type == "like":
            msgid2likes[message.id].add(user.id)
        elif act.type == "repost":
            msgid2retweets[message.id].add(user.id)
    time_step = time_next

# Save the per-message read/like/retweet counts to test.txt once, after the
# whole window has been processed; non-post message types are printed instead.
with open("test.txt", "w", encoding="utf-8") as f:
    for message in env.messages:
        if message.type == "post":
            print(f"Time: {message.timestamp}. Reads: {len(msgid2reads[message.id])}. Likes: {len(msgid2likes[message.id])},{message.liked_by}. Retweets: {len(msgid2retweets[message.id])},{message.reposted_by}. Message: {message.text}", file=f)
        else:
            print(message.type)


In [None]:
from collections import Counter

# Overall totals across the whole log: posts + reposts, then likes.
total_act_counts = Counter(log["act"]["type"] for log in env.log)
print(total_act_counts["post"] + total_act_counts["repost"])
print(total_act_counts["like"])

# Per-interval counts, printed in the order read, like, repost, post.
time_sim_begin = conf["time_init_begin"]
time_sim_end = conf["time_sim_end"]
# The other cells read the interval from conf["interval"]; fall back to the
# "time_intv" key only if "interval" is absent.
time_intv = conf["interval"] if "interval" in conf else conf["time_intv"]
time_step = time_sim_begin
while time_step < time_sim_end:
    time_next = add_interval(time_step, time_intv)
    step_act_counts = Counter(
        log["act"]["type"] for log in env.log if time_step <= log["timestamp"] < time_next
    )
    print(f"Time: {time_step}")
    for act_type in ("read", "like", "repost", "post"):
        print(step_act_counts[act_type])
    time_step = time_next