In [None]:
# Mount Google Drive so the data and GloVe paths under /content/drive/MyDrive (see Args) are readable.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
import pandas as pd
from tqdm import tqdm
from nltk.tokenize import word_tokenize
from collections import Counter
import numpy as np
import random
import logging
import os
import torch.optim as optim
import torch.distributed as dist
from torch.utils.data import DataLoader
from pathlib import Path



In [None]:
from dataclasses import dataclass
from typing import Optional

@dataclass
class Args:
    """Run configuration for this notebook; stands in for command-line flags."""
    nGPU: int = 1  # number of per-GPU training shards / training processes
    seed: int = 0  # seeds Python `random`, torch, and the negative sampling in prepare_training_data
    prepare: bool = True  # regenerate the per-GPU training shards from behaviors.tsv before training
    mode: str = "train"  # the data-preparation cell runs when "train" is a substring of this
    train_data_dir: str = "/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/MINDsmall_train"  # holds news.tsv / behaviors.tsv; shards are written here
    test_data_dir: str = "/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/MINDsmall_dev"  # evaluation split directory
    custom_abstract_dir: str = ""  # despite the name, a CSV file path read with pd.read_csv (columns `news_id`, `abstract`)
    model_dir: str = '/content/model'  # checkpoints are saved to and loaded from this directory
    batch_size: int = 32
    npratio: int = 4  # negatives sampled per clicked (positive) news
    enable_gpu: bool = True  # move the model and batches to CUDA
    filter_num: int = 3  # a word enters the vocabulary only if it occurs MORE than this many times
    log_steps: int = 100  # log running loss / accuracy every this many batches
    epochs: int = 5
    lr: float = 0.0003  # Adam learning rate
    num_words_title: int = 20  # title tokens kept per news (truncated / zero-padded)
    num_words_abstract: int = 50  # abstract tokens kept per news (truncated / zero-padded)
    user_log_length: int = 50  # clicked-history length: last entries kept, front-padded
    word_embedding_dim: int = 300  # must match the vector size stored in glove_embedding_path
    glove_embedding_path: str = '/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/glove.840B.300d.txt'
    freeze_embedding: bool = False  # keep the pretrained word embeddings fixed during training
    news_dim: int = 400  # size of news vectors (and therefore of user vectors)
    news_query_vector_dim: int = 200  # hidden size of the news-side attention pooling
    user_query_vector_dim: int = 200  # hidden size of the user-side attention pooling
    num_attention_heads: int = 20  # not referenced by the NAML code in this notebook
    user_log_mask: bool = False  # True: mask padded history in attention; False: substitute a learned pad vector
    drop_rate: float = 0.2  # dropout applied to word embeddings
    save_steps: int = 10000  # extra mid-epoch checkpoint every this many batches
    start_epoch: int = 0  # first epoch index of the training loop
    load_ckpt_name: Optional[str] = None  # checkpoint file name inside model_dir to load before training
    use_category: bool = True
    use_subcategory: bool = True
    use_abstract: bool = True
    use_custom_abstract: bool = False  # replace news.tsv abstracts with those from custom_abstract_dir
    category_emb_dim: int = 100

def parse_args():
    """Return the run configuration.

    No command line is parsed despite the name: this returns an ``Args`` instance
    populated with the dataclass defaults.
    """
    config = Args()
    return config


**Dataset.py**

In [None]:
from torch.utils.data import IterableDataset, Dataset
import numpy as np
import random


class DatasetTrain(IterableDataset):
    """Streams training samples from a prepared ``behaviors_np{K}_{rank}.tsv`` file.

    Each line yields ``(user_feature, log_mask, news_feature, label)``:
      user_feature: rows of ``news_combined`` for the clicked history, truncated/padded
                    to ``args.user_log_length``
      log_mask:     float32 array, 1 for real history entries and 0 for padding
      news_feature: rows of ``news_combined`` for the sampled negatives with the
                    positive inserted at position ``label``
      label:        index of the positive inside ``news_feature``
    """

    def __init__(self, filename, news_index, news_combined, args):
        # The previous `super(DatasetTrain).__init__()` created an unbound super object and
        # never ran the base-class initializer; use the regular zero-argument form.
        super().__init__()
        self.filename = filename
        self.news_index = news_index
        self.news_combined = news_combined
        self.args = args

    def trans_to_nindex(self, nids):
        """Map news IDs to row indices; IDs missing from news_index map to row 0 (padding)."""
        return [self.news_index.get(i, 0) for i in nids]

    def pad_to_fix_len(self, x, fix_length, padding_front=True, padding_value=0):
        """Keep the last ``fix_length`` items of ``x`` and pad to exactly that length.

        Returns the padded list and a float32 mask with 1 for real items, 0 for padding.
        """
        if padding_front:
            pad_x = [padding_value] * (fix_length - len(x)) + x[-fix_length:]
            mask = [0] * (fix_length - len(x)) + [1] * min(fix_length, len(x))
        else:
            pad_x = x[-fix_length:] + [padding_value] * (fix_length - len(x))
            mask = [1] * min(fix_length, len(x)) + [0] * (fix_length - len(x))
        return pad_x, np.array(mask, dtype='float32')

    def line_mapper(self, line):
        """Parse ``iid, uid, time, history, pos_id, negatives`` (tab-separated) into a sample."""
        line = line.strip().split('\t')
        click_docs = line[3].split()
        sess_pos = line[4].split()
        sess_neg = line[5].split()

        click_docs, log_mask = self.pad_to_fix_len(self.trans_to_nindex(click_docs), self.args.user_log_length)
        user_feature = self.news_combined[click_docs]

        pos = self.trans_to_nindex(sess_pos)
        neg = self.trans_to_nindex(sess_neg)

        # Insert the positive at a random slot (0..npratio inclusive); that slot is the class label.
        label = random.randint(0, self.args.npratio)
        sample_news = neg[:label] + pos + neg[label:]
        news_feature = self.news_combined[sample_news]

        return user_feature, log_mask, news_feature, label

    def __iter__(self):
        # Re-open the file for every pass (e.g. each epoch) and close it when the pass ends;
        # the previous `map(..., open(...))` never closed the handle.
        with open(self.filename) as file_iter:
            for line in file_iter:
                yield self.line_mapper(line)


class DatasetTest(DatasetTrain):
    """Streams evaluation impressions from a behaviors file.

    Each line yields ``(user_feature, log_mask, news_feature, labels)`` where
    ``news_feature`` holds one ``news_scoring`` row per candidate in the impression
    and ``labels`` holds the matching integer click labels.
    """

    def __init__(self, filename, news_index, news_scoring, args):
        # Deliberately skip DatasetTrain.__init__ (it expects news_combined) and run the
        # IterableDataset initializer directly. The previous `super(DatasetTrain).__init__()`
        # built an unbound super object and ran no base initializer at all.
        super(DatasetTrain, self).__init__()
        self.filename = filename
        self.news_index = news_index
        self.news_scoring = news_scoring
        self.args = args

    def line_mapper(self, line):
        """Parse ``iid, uid, time, history, impressions`` (tab-separated) into model inputs."""
        line = line.strip().split('\t')
        click_docs = line[3].split()
        click_docs, log_mask = self.pad_to_fix_len(self.trans_to_nindex(click_docs), self.args.user_log_length)
        user_feature = self.news_scoring[click_docs]

        # Each impression entry is "<news_id>-<label>".
        impressions = line[4].split()
        candidate_news = self.trans_to_nindex([i.split('-')[0] for i in impressions])
        labels = np.array([int(i.split('-')[1]) for i in impressions])
        news_feature = self.news_scoring[candidate_news]

        return user_feature, log_mask, news_feature, labels

    def __iter__(self):
        # Close the file when the pass ends instead of leaking one handle per pass.
        with open(self.filename) as file_iter:
            for line in file_iter:
                yield self.line_mapper(line)


class NewsDataset(Dataset):
    """Map-style dataset serving the rows (first axis) of a pre-built array."""

    def __init__(self, data):
        self.data = data

    def __len__(self):
        num_rows = self.data.shape[0]
        return num_rows

    def __getitem__(self, idx):
        row = self.data[idx]
        return row


**Metric.py**

In [None]:
from sklearn.metrics import roc_auc_score
import numpy as np


def dcg_score(y_true, y_score, k=10):
    """Discounted cumulative gain of the top-``k`` items when ranked by ``y_score``.

    An item with relevance ``r`` at 0-based rank ``p`` contributes ``(2**r - 1) / log2(p + 2)``.
    """
    top_k = np.argsort(y_score)[::-1][:k]
    relevance = np.take(y_true, top_k)
    positions = np.arange(len(relevance))
    return np.sum((2**relevance - 1) / np.log2(positions + 2))


def ndcg_score(y_true, y_score, k=10):
    """Normalized DCG@k: DCG of the ``y_score`` ranking over the DCG of the ideal ranking.

    NOTE(review): when ``y_true`` contains no positive label the ideal DCG is 0 and the
    result is NaN (numpy emits a RuntimeWarning).
    """
    ideal = dcg_score(y_true, y_true, k)
    return dcg_score(y_true, y_score, k) / ideal


def mrr_score(y_true, y_score):
    """Mean reciprocal rank of the positive items when ranked by ``y_score`` (1-based ranks)."""
    ranked = np.take(y_true, np.argsort(y_score)[::-1])
    reciprocal_ranks = ranked / np.arange(1, len(ranked) + 1)
    return np.sum(reciprocal_ranks) / np.sum(ranked)


def ctr_score(y_true, y_score, k=1):
    """Fraction of positives among the top-``k`` items ranked by ``y_score``."""
    top_k = np.argsort(y_score)[::-1][:k]
    return np.mean(np.take(y_true, top_k))

def acc(y_true, y_hat):
    """Top-1 accuracy: share of rows whose arg-max of ``y_hat`` equals ``y_true``.

    Returns a 0-dim float tensor.
    """
    predicted = y_hat.argmax(dim=-1)
    correct = (predicted == y_true).sum()
    return correct.data.float() * 1.0 / y_true.shape[0]



**Utils.py**

In [None]:
import logging
import argparse
import sys

def setuplogger(exclusive=False):
    """Route INFO-and-above records from the root logger to stdout with a timestamped format.

    Safe to call repeatedly (e.g. when this notebook cell is re-run): the stdout handler
    is attached only once, so messages are not printed once per call.

    Args:
        exclusive: if True, first remove every handler already attached to the root
            logger. Use this when some other pre-installed handler makes each message
            appear a second time in a different format.
    """
    handler_name = "setuplogger-stdout"
    root = logging.getLogger()
    root.setLevel(logging.INFO)

    if exclusive:
        for existing in list(root.handlers):
            root.removeHandler(existing)

    # Our handler is tagged by name so a re-run can recognise it and not add a duplicate.
    if any(h.get_name() == handler_name for h in root.handlers):
        return

    handler = logging.StreamHandler(sys.stdout)
    handler.set_name(handler_name)
    handler.setLevel(logging.INFO)
    handler.setFormatter(logging.Formatter("[%(levelname)s %(asctime)s] %(message)s"))
    root.addHandler(handler)


def dump_args(args):
    """Log each public attribute of ``args`` as ``args[name]=value``, in ``dir()`` order."""
    public_names = [name for name in dir(args) if not name.startswith("_")]
    for name in public_names:
        logging.info(f"args[{name}]={getattr(args, name)}")

def load_matrix(embedding_file_path, word_dict, word_embedding_dim):
    """Build an embedding matrix for ``word_dict`` from a GloVe-style text file.

    Args:
        embedding_file_path: path to a file with lines ``<word> <v1> ... <vD>``, or None
            to skip loading (all rows stay zero).
        word_dict: word -> row index (indices are expected to be 1..len(word_dict)).
        word_embedding_dim: D, the number of values per vector.

    Returns:
        (embedding_matrix, have_word): a zero-initialised array of shape
        ``(len(word_dict) + 1, word_embedding_dim)`` whose row ``word_dict[w]`` holds the
        vector of every word ``w`` found in the file, and the list of found words in
        file order.

    Raises:
        ValueError: a line for a dictionary word does not hold exactly D values.
    """
    embedding_matrix = np.zeros(shape=(len(word_dict) + 1, word_embedding_dim))
    have_word = []
    if embedding_file_path is None:
        return embedding_matrix, have_word

    with open(embedding_file_path, 'rb') as f:
        for line in f:
            # Split exactly D values off the right end: some GloVe tokens contain spaces,
            # and splitting from the left took only the first piece as the word and tried to
            # parse the remaining pieces as floats.
            parts = line.strip().rsplit(None, word_embedding_dim)
            if not parts:
                continue  # blank line
            word = parts[0].decode()
            if word not in word_dict:
                continue
            values = parts[1:]
            if len(values) != word_embedding_dim:
                # Without this check a single value would silently broadcast over the row.
                raise ValueError(
                    f"Embedding for {word!r} has {len(values)} values, expected {word_embedding_dim}.")
            embedding_matrix[word_dict[word]] = np.array([float(x) for x in values])
            have_word.append(word)
    return embedding_matrix, have_word


def get_checkpoint(directory, ckpt_name):
    """Return ``directory/ckpt_name`` when that path exists, otherwise None."""
    candidate = os.path.join(directory, ckpt_name)
    return candidate if os.path.exists(candidate) else None


**Model_utils.py**

In [None]:
from torch import nn
class AttentionPooling(nn.Module):
    """Additive attention pooling: collapses a sequence of vectors into a weighted sum.

    Position i gets the score ``a_i = fc2(tanh(fc1(x_i)))``. The weights are
    ``exp(a_i) * mask_i`` normalised over the sequence, which is a softmax over the kept
    positions for a 0/1 mask. The output is ``sum_i w_i * x_i``.
    """

    def __init__(self, emb_size, hidden_size):
        super(AttentionPooling, self).__init__()
        self.att_fc1 = nn.Linear(emb_size, hidden_size)
        self.att_fc2 = nn.Linear(hidden_size, 1)

    def forward(self, x, attn_mask=None):
        """
        Args:
            x: batch_size, candidate_size, emb_dim
            attn_mask: batch_size, candidate_size (0 = ignore the position)
        Returns:
            (shape) batch_size, emb_dim
        """
        e = torch.tanh(self.att_fc1(x))
        alpha = self.att_fc2(e)  # batch_size, candidate_size, 1

        mask = attn_mask.unsqueeze(2) if attn_mask is not None else None

        # Subtract the per-row maximum before exp so large scores cannot overflow to inf
        # (inf / inf -> NaN); the normalised weights are shift-invariant. Masked positions
        # are excluded from the max so they cannot push kept scores into underflow.
        if mask is not None:
            shift = alpha.masked_fill(mask == 0, float('-inf')).max(dim=1, keepdim=True).values
        else:
            shift = alpha.max(dim=1, keepdim=True).values
        # A fully masked row has max -inf; use 0 so it stays finite (its weights become 0 below).
        shift = torch.where(torch.isinf(shift), torch.zeros_like(shift), shift)
        alpha = torch.exp(alpha - shift.detach())

        if mask is not None:
            alpha = alpha * mask

        alpha = alpha / (torch.sum(alpha, dim=1, keepdim=True) + 1e-8)
        x = torch.bmm(x.permute(0, 2, 1), alpha).squeeze(dim=-1)
        return x


**NAML.py**

In [None]:
import torch
from torch import nn
import torch.nn.functional as F


class NewsEncoder(nn.Module):
    """Encodes a packed news feature row into one ``news_dim`` vector.

    Row layout (a column group is present only when its flag is on):
        [title word ids (num_words_title) | category id (1) | subcategory id (1) |
         abstract word ids (num_words_abstract)]
    Title and abstract go through word embedding -> dropout -> Conv1d -> attention pooling;
    category/subcategory go through an embedding and a linear layer. With more than one
    view, the views are combined by a final attention pooling.
    """

    def __init__(self, args, embedding_matrix, num_category, num_subcategory):
        super(NewsEncoder, self).__init__()
        self.embedding_matrix = embedding_matrix  # word nn.Embedding shared by title and abstract
        self.drop_rate = args.drop_rate
        self.num_words_title = args.num_words_title
        self.use_category = args.use_category
        self.use_subcategory = args.use_subcategory
        self.use_abstract = args.use_abstract
        self.num_words_abstract = args.num_words_abstract
        if args.use_category:
            self.category_emb = nn.Embedding(num_category + 1, args.category_emb_dim, padding_idx=0)
            self.category_dense = nn.Linear(args.category_emb_dim, args.news_dim)
        if args.use_subcategory:
            self.subcategory_emb = nn.Embedding(num_subcategory + 1, args.category_emb_dim, padding_idx=0)
            self.subcategory_dense = nn.Linear(args.category_emb_dim, args.news_dim)
        # final_attn is needed whenever forward produces more than one view. The title is always
        # present, so this includes the abstract-only configuration, which previously raised
        # AttributeError in forward.
        if args.use_category or args.use_subcategory or args.use_abstract:
            self.final_attn = AttentionPooling(args.news_dim, args.news_query_vector_dim)
        self.cnn = nn.Conv1d(
            in_channels=args.word_embedding_dim,
            out_channels=args.news_dim,
            kernel_size=3,
            padding=1
        )
        self.attn = AttentionPooling(args.news_dim, args.news_query_vector_dim)

        if args.use_abstract:
            self.abstract_cnn = nn.Conv1d(
                in_channels=args.word_embedding_dim,
                out_channels=args.news_dim,
                kernel_size=3,
                padding=1
            )
            self.abstract_attn = AttentionPooling(args.news_dim, args.news_query_vector_dim)

    def forward(self, x, mask=None, abstract_mask=None):
        '''
            x: batch_size, num_features (packed row, layout in the class docstring)
            mask: optional title-word mask, batch_size x num_words_title
            abstract_mask: optional abstract-word mask, batch_size x num_words_abstract
        Returns:
            batch_size, news_dim
        '''
        title = torch.narrow(x, -1, 0, self.num_words_title).long()
        word_vecs = F.dropout(self.embedding_matrix(title),
                              p=self.drop_rate,
                              training=self.training)
        context_word_vecs = self.cnn(word_vecs.transpose(1, 2)).transpose(1, 2)
        title_vecs = self.attn(context_word_vecs, mask)
        all_vecs = [title_vecs]

        start = self.num_words_title
        if self.use_category:
            category = torch.narrow(x, -1, start, 1).squeeze(dim=-1).long()
            category_vecs = self.category_dense(self.category_emb(category))
            all_vecs.append(category_vecs)
            start += 1
        if self.use_subcategory:
            subcategory = torch.narrow(x, -1, start, 1).squeeze(dim=-1).long()
            subcategory_vecs = self.subcategory_dense(self.subcategory_emb(subcategory))
            all_vecs.append(subcategory_vecs)
            # Previously missing: the abstract slice started on the subcategory column, feeding
            # the subcategory id in as the first abstract word and dropping the last abstract
            # word. Checkpoints trained before this fix saw that shifted slice.
            start += 1

        if self.use_abstract:
            abstract = torch.narrow(x, -1, start, self.num_words_abstract).long()
            abstract_word_vecs = F.dropout(self.embedding_matrix(abstract),
                                           p=self.drop_rate,
                                           training=self.training)
            abstract_context_word_vecs = self.abstract_cnn(abstract_word_vecs.transpose(1, 2)).transpose(1, 2)
            # The title mask has title length; the abstract gets its own optional mask.
            abstract_vecs = self.abstract_attn(abstract_context_word_vecs, abstract_mask)
            all_vecs.append(abstract_vecs)

        if len(all_vecs) == 1:
            news_vecs = all_vecs[0]
        else:
            all_vecs = torch.stack(all_vecs, dim=1)
            news_vecs = self.final_attn(all_vecs)
        return news_vecs

class UserEncoder(nn.Module):
    """Builds a user vector by attention-pooling the encoded clicked-news history."""

    def __init__(self, args):
        super(UserEncoder, self).__init__()
        self.args = args
        self.attn = AttentionPooling(args.news_dim, args.user_query_vector_dim)
        # Learnable stand-in for padded history slots (used when args.user_log_mask is False).
        # Created directly as a float32 Parameter. The previous `nn.Parameter(...).type(...)`
        # only remained a registered Parameter because `.type` returned the same object for an
        # already-FloatTensor input; any conversion would have produced an untrained plain tensor.
        self.pad_doc = nn.Parameter(torch.empty(1, args.news_dim, dtype=torch.float32).uniform_(-1, 1))

    def forward(self, news_vecs, log_mask=None):
        '''
            news_vecs: batch_size, history_num, news_dim
            log_mask: batch_size, history_num (1 = real click, 0 = padding); None = no padding
        Returns:
            batch_size, news_dim
        '''
        if log_mask is None:
            return self.attn(news_vecs)
        if self.args.user_log_mask:
            return self.attn(news_vecs, log_mask)

        bz, history_num = news_vecs.shape[0], news_vecs.shape[1]
        mask = log_mask.unsqueeze(dim=-1)
        # Use the actual history length rather than args.user_log_length so any length works.
        padding_doc = self.pad_doc.unsqueeze(dim=0).expand(bz, history_num, -1)
        news_vecs = news_vecs * mask + padding_doc * (1 - mask)
        return self.attn(news_vecs)


class Model(torch.nn.Module):
    """NAML-style recommender: scores candidate news by dot product with a user vector.

    Args:
        args: run configuration (see Args).
        embedding_matrix: numpy array of initial word vectors; row 0 is the padding word.
        num_category / num_subcategory: number of category / subcategory ids (id 0 is padding).
    """

    def __init__(self, args, embedding_matrix, num_category, num_subcategory, **kwargs):
        super(Model, self).__init__()
        self.args = args
        pretrained_word_embedding = torch.from_numpy(embedding_matrix).float()
        word_embedding = nn.Embedding.from_pretrained(pretrained_word_embedding,
                                                      freeze=args.freeze_embedding,
                                                      padding_idx=0)

        self.news_encoder = NewsEncoder(args, word_embedding, num_category, num_subcategory)
        self.user_encoder = UserEncoder(args)
        self.loss_fn = nn.CrossEntropyLoss()

    def forward(self, history, history_mask, candidate, label):
        '''
            history: batch_size, history_length, num_features (packed news rows)
            history_mask: batch_size, history_length
            candidate: batch_size, num_candidates (1+K during training), num_features
            label: batch_size - index of the clicked candidate (CrossEntropyLoss class target)
        Returns:
            (loss, score) with score of shape batch_size x num_candidates
        '''
        num_words = history.shape[-1]
        # Read the sequence lengths from the inputs instead of args.npratio / args.user_log_length
        # so the same forward pass works for other candidate or history counts.
        num_candidates = candidate.shape[1]
        history_length = history.shape[1]

        candidate_news = candidate.reshape(-1, num_words)
        candidate_news_vecs = self.news_encoder(candidate_news).reshape(-1, num_candidates, self.args.news_dim)

        history_news = history.reshape(-1, num_words)
        history_news_vecs = self.news_encoder(history_news).reshape(-1, history_length, self.args.news_dim)

        user_vec = self.user_encoder(history_news_vecs, history_mask)
        score = torch.bmm(candidate_news_vecs, user_vec.unsqueeze(dim=-1)).squeeze(dim=-1)
        loss = self.loss_fn(score, label)
        return loss, score


**preprocess.py**

In [None]:
from collections import Counter
from tqdm import tqdm
import numpy as np
from nltk.tokenize import word_tokenize


def update_dict(dict, key, value=None):
    """Insert ``key`` into ``dict`` only if it is absent (first value wins).

    With ``value`` None the key gets the next 1-based id, ``len(dict) + 1``;
    otherwise it is mapped to ``value``. Returns None.
    """
    if key in dict:
        return
    dict[key] = len(dict) + 1 if value is None else value


def read_custom_abstract(news_file, custom_abstract_dict):
    """Parse a MIND news.tsv file, replacing abstracts with user-supplied ones.

    Titles and abstracts are lower-cased and tokenized with ``word_tokenize`` exactly as
    ``read_news`` tokenizes titles, so the counted tokens match the tokens looked up later.

    Args:
        news_file: path to news.tsv (8 tab-separated columns per line).
        custom_abstract_dict: doc_id -> replacement abstract. Non-string values (such as
            the NaN pandas produces for an empty CSV cell) are ignored, keeping the
            news.tsv abstract.

    Returns:
        news: doc_id -> [title_tokens, category, subcategory, abstract_tokens]
        news_index: doc_id -> 1-based row index (first occurrence wins)
        category_dict / subcategory_dict: name -> 1-based id
        word_cnt: Counter of title and abstract token frequencies (unfiltered)
    """
    news = {}
    news_index = {}
    category_dict = {}
    subcategory_dict = {}
    word_cnt = Counter()

    with open(news_file, 'r', encoding='utf-8') as f:
        for line in f:
            splited = line.strip('\n').split('\t')
            doc_id, category, subcategory, title, abstract, url, entity_title, entity_abstract = splited
            custom = custom_abstract_dict.get(doc_id)
            if isinstance(custom, str):
                abstract = custom

            title_tokens = word_tokenize(title.lower(), language='english', preserve_line=True)
            abstract_tokens = word_tokenize(abstract.lower(), language='english', preserve_line=True)
            news[doc_id] = [title_tokens, category, subcategory, abstract_tokens]
            # setdefault keeps ids unique even if a doc_id repeats.
            news_index.setdefault(doc_id, len(news_index) + 1)
            word_cnt.update(title_tokens)
            word_cnt.update(abstract_tokens)
            category_dict.setdefault(category, len(category_dict) + 1)
            subcategory_dict.setdefault(subcategory, len(subcategory_dict) + 1)

    return news, news_index, category_dict, subcategory_dict, word_cnt

def read_news(news_path, args, mode='train'):
    """Parse a MIND news.tsv file.

    Titles are lower-cased and tokenized with ``word_tokenize``. When ``args.use_abstract``
    is set, the abstract is tokenized the same way; otherwise it is kept as the raw string,
    which get_doc_input does not read in that case.

    In 'train' mode this also builds the category/subcategory id maps (when enabled) and
    the vocabulary: title words occurring more than ``args.filter_num`` times, with
    1-based ids. Abstract tokens are not counted, so only abstract words that also appear
    in titles get non-zero ids.

    Returns:
        mode 'train': (news, news_index, category_dict, subcategory_dict, word_dict)
        mode 'test':  (news, news_index)
        where news maps doc_id -> [title_tokens, category, subcategory, abstract] and
        news_index maps doc_id -> 1-based row index.

    Raises:
        AssertionError: if mode is neither 'train' nor 'test'.
    """
    if mode not in ('train', 'test'):
        # Checked before reading the file; raised explicitly so it also works under `python -O`.
        raise AssertionError('Wrong mode!')

    news = {}
    category_dict = {}
    subcategory_dict = {}
    news_index = {}
    word_cnt = Counter()

    with open(news_path, 'r', encoding='utf-8') as f:
        for line in tqdm(f):
            splited = line.strip('\n').split('\t')
            doc_id, category, subcategory, title, abstract, _url, _, _ = splited
            update_dict(news_index, doc_id)

            title = word_tokenize(title.lower(), language='english', preserve_line=True)
            if args.use_abstract:
                # Previously the raw string was stored, so get_doc_input looked up the
                # abstract's individual characters in word_dict instead of its words.
                abstract = word_tokenize(abstract.lower(), language='english', preserve_line=True)

            update_dict(news, doc_id, [title, category, subcategory, abstract])
            if mode == 'train':
                if args.use_category:
                    update_dict(category_dict, category)
                if args.use_subcategory:
                    update_dict(subcategory_dict, subcategory)
                word_cnt.update(title)

    if mode == 'test':
        return news, news_index

    frequent_words = [k for k, v in word_cnt.items() if v > args.filter_num]
    word_dict = {word: idx for idx, word in enumerate(frequent_words, start=1)}
    return news, news_index, category_dict, subcategory_dict, word_dict


def get_doc_input(news, news_index, category_dict, subcategory_dict, word_dict, args):
    """Convert parsed news into fixed-width int32 arrays indexed by ``news_index``.

    Row ``news_index[doc_id]`` holds that document's ids. Rows not referenced by
    ``news_index`` (row 0 when indices start at 1) stay all-zero. Words or categories
    missing from the given dicts map to 0, and sequences are truncated / zero-padded.

    Returns:
        news_title (N+1, num_words_title),
        news_category (N+1, 1) or None, news_subcategory (N+1, 1) or None,
        news_abstract (N+1, num_words_abstract) or None, where N = len(news).
    """
    news_num = len(news) + 1
    news_title = np.zeros((news_num, args.num_words_title), dtype='int32')
    news_category = np.zeros((news_num, 1), dtype='int32') if args.use_category else None
    news_subcategory = np.zeros((news_num, 1), dtype='int32') if args.use_subcategory else None
    news_abstract = np.zeros((news_num, args.num_words_abstract), dtype='int32') if args.use_abstract else None

    for key in tqdm(news):
        title, category, subcategory, abstract = news[key]
        doc_index = news_index[key]

        for word_id, word in enumerate(title[:args.num_words_title]):
            news_title[doc_index, word_id] = word_dict.get(word, 0)

        if args.use_category:
            news_category[doc_index, 0] = category_dict.get(category, 0)
        if args.use_subcategory:
            news_subcategory[doc_index, 0] = subcategory_dict.get(subcategory, 0)
        if args.use_abstract:
            if isinstance(abstract, str):
                # A raw abstract string would be indexed character by character; tokenize it
                # the same way read_news tokenizes titles.
                abstract = word_tokenize(abstract.lower(), language='english', preserve_line=True)
            for word_id, word in enumerate(abstract[:args.num_words_abstract]):
                news_abstract[doc_index, word_id] = word_dict.get(word, 0)

    return news_title, news_category, news_subcategory, news_abstract

**prepare_data.py**

In [None]:
import os
from tqdm import tqdm
import random
import logging


def get_sample(all_elements, num_sample):
    if num_sample > len(all_elements):
        return random.sample(all_elements * (num_sample // len(all_elements) + 1), num_sample)
    else:
        return random.sample(all_elements, num_sample)


def prepare_training_data(train_data_dir, nGPU, npratio, seed):
    """Turn behaviors.tsv into per-GPU training shards with sampled negatives.

    Impressions with no clicked or no non-clicked news are skipped. Every clicked news of
    the remaining impressions becomes one tab-separated line:
        iid, uid, time, history, clicked_id, <npratio sampled negatives joined by ' '>
    The lines are shuffled (after seeding `random` with ``seed``) and dealt round-robin
    into ``behaviors_np{npratio}_{i}.tsv`` for i in range(nGPU).

    Returns:
        The total number of sample lines.
    """
    random.seed(seed)
    source_path = os.path.join(train_data_dir, 'behaviors.tsv')

    samples = []
    with open(source_path, 'r', encoding='utf-8') as reader:
        for raw_line in tqdm(reader):
            imp_id, user_id, timestamp, history, impressions = raw_line.strip().split('\t')
            parsed = [entry.split('-') for entry in impressions.split(' ')]
            clicked = [news_id for news_id, flag in parsed if flag == '1']
            skipped = [news_id for news_id, flag in parsed if flag == '0']
            if not clicked or not skipped:
                continue
            for clicked_id in clicked:
                negatives = ' '.join(get_sample(skipped, npratio))
                samples.append('\t'.join([imp_id, user_id, timestamp, history, clicked_id, negatives]) + '\n')

    random.shuffle(samples)

    shards = [[] for _ in range(nGPU)]
    for position, sample in enumerate(samples):
        shards[position % nGPU].append(sample)

    logging.info('Writing files...')
    for shard_id, shard in enumerate(shards):
        out_path = os.path.join(train_data_dir, f'behaviors_np{npratio}_{shard_id}.tsv')
        with open(out_path, 'w') as writer:
            writer.writelines(shard)

    return len(samples)


def prepare_testing_data(test_data_dir, nGPU):
    """Deal the lines of ``behaviors.tsv`` round-robin into ``behaviors_{i}.tsv``, i in range(nGPU).

    Returns:
        The total number of lines distributed.
    """
    shards = [[] for _ in range(nGPU)]
    source_path = os.path.join(test_data_dir, 'behaviors.tsv')
    with open(source_path, 'r', encoding='utf-8') as reader:
        for line_no, raw_line in enumerate(tqdm(reader)):
            shards[line_no % nGPU].append(raw_line)

    logging.info('Writing files...')
    for shard_id, shard in enumerate(shards):
        out_path = os.path.join(test_data_dir, f'behaviors_{shard_id}.tsv')
        with open(out_path, 'w') as writer:
            writer.writelines(shard)

    return sum(len(shard) for shard in shards)


In [None]:
def _save_checkpoint(ckpt_path, model, is_distributed, category_dict, subcategory_dict, word_dict):
    """Save model weights plus the vocabularies needed to rebuild model inputs."""
    state_dict = model.state_dict()
    if is_distributed:
        # Drop the wrapper's leading key component so the checkpoint loads into a plain Model.
        state_dict = {'.'.join(k.split('.')[1:]): v for k, v in state_dict.items()}
    torch.save(
        {
            'model_state_dict': state_dict,
            'category_dict': category_dict,
            'subcategory_dict': subcategory_dict,
            'word_dict': word_dict,
        }, ckpt_path)
    logging.info(f"Model saved to {ckpt_path}.")


def train(rank, args):
    """Train the model on this rank's prepared shard ``behaviors_np{npratio}_{rank}.tsv``.

    Args:
        rank: shard index; also the CUDA device index when ``args.enable_gpu`` is set.
        args: run configuration (see Args).

    Rank 0 writes checkpoints (weights + category/subcategory/word dicts) to
    ``args.model_dir`` every ``args.save_steps`` batches and at the end of every epoch.

    Raises:
        FileNotFoundError: ``args.load_ckpt_name`` is set but not present in ``args.model_dir``.
    """
    is_distributed = False
    if args.enable_gpu:
        # Only select a CUDA device when GPU training is enabled; this used to run
        # unconditionally and fail on CPU-only runtimes.
        torch.cuda.set_device(rank)

    news_path = os.path.join(args.train_data_dir, 'news.tsv')
    if args.use_custom_abstract:
        custom_abstract_df = pd.read_csv(args.custom_abstract_dir)
        custom_abstract_dict = custom_abstract_df.set_index('news_id')['abstract'].to_dict()
        news, news_index, category_dict, subcategory_dict, word_cnt = read_custom_abstract(
            news_path, custom_abstract_dict)
        # read_custom_abstract returns raw counts. Apply the same frequency filter as
        # read_news(mode='train') so word_dict exists on this path too (it was previously
        # undefined here and raised NameError).
        frequent_words = [w for w, c in word_cnt.items() if c > args.filter_num]
        word_dict = {w: i for i, w in enumerate(frequent_words, start=1)}
    else:
        news, news_index, category_dict, subcategory_dict, word_dict = read_news(
            news_path, args, mode='train')

    news_title, news_category, news_subcategory, news_abstract = get_doc_input(
        news, news_index, category_dict, subcategory_dict, word_dict, args)
    # Packed per-news rows: [title | category | subcategory | abstract], omitting disabled parts.
    news_combined = np.concatenate(
        [x for x in [news_title, news_category, news_subcategory, news_abstract] if x is not None], axis=-1)

    if rank == 0:
        logging.info('Initializing word embedding matrix...')

    embedding_matrix, have_word = load_matrix(args.glove_embedding_path,
                                              word_dict,
                                              args.word_embedding_dim)
    if rank == 0:
        logging.info(f'Word dict length: {len(word_dict)}')
        logging.info(f'Have words: {len(have_word)}')
        logging.info(f'Missing rate: {(len(word_dict) - len(have_word)) / max(len(word_dict), 1)}')

    model = Model(args, embedding_matrix, len(category_dict), len(subcategory_dict))

    if args.load_ckpt_name is not None:
        ckpt_path = get_checkpoint(args.model_dir, args.load_ckpt_name)
        if ckpt_path is None:
            raise FileNotFoundError(f'Checkpoint {args.load_ckpt_name} not found in {args.model_dir}.')
        checkpoint = torch.load(ckpt_path, map_location='cpu')
        model.load_state_dict(checkpoint['model_state_dict'])
        logging.info(f"Model loaded from {ckpt_path}.")

    if args.enable_gpu:
        model = model.cuda(rank)
    # Construct the optimizer after moving the model, as the torch.optim docs recommend.
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    if is_distributed:
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[rank])

    data_file_path = os.path.join(args.train_data_dir, f'behaviors_np{args.npratio}_{rank}.tsv')

    dataset = DatasetTrain(data_file_path, news_index, news_combined, args)
    dataloader = DataLoader(dataset, batch_size=args.batch_size)

    logging.info('Training...')
    model.train()
    for ep in range(args.start_epoch, args.epochs):
        running_loss = 0.0
        running_acc = 0.0
        for cnt, (log_ids, log_mask, input_ids, targets) in enumerate(dataloader):
            if args.enable_gpu:
                log_ids = log_ids.cuda(rank, non_blocking=True)
                log_mask = log_mask.cuda(rank, non_blocking=True)
                input_ids = input_ids.cuda(rank, non_blocking=True)
                targets = targets.cuda(rank, non_blocking=True)

            bz_loss, y_hat = model(log_ids, log_mask, input_ids, targets)
            running_loss += bz_loss.detach().float()
            running_acc += acc(targets, y_hat)
            optimizer.zero_grad()
            bz_loss.backward()
            optimizer.step()

            if cnt % args.log_steps == 0:
                # cnt is 0-based, so cnt + 1 batches have been accumulated. Dividing by cnt
                # logged inf on the first batch and overstated both averages afterwards.
                seen = cnt + 1
                logging.info(
                    '[{}] Ed: {}, train_loss: {:.5f}, acc: {:.5f}'.format(
                        rank, seen * args.batch_size, float(running_loss) / seen, float(running_acc) / seen)
                )

            if rank == 0 and cnt != 0 and cnt % args.save_steps == 0:
                _save_checkpoint(os.path.join(args.model_dir, f'epoch-{ep+1}-{cnt}.pt'),
                                 model, is_distributed, category_dict, subcategory_dict, word_dict)

        if rank == 0:
            _save_checkpoint(os.path.join(args.model_dir, f'epoch-{ep+1}.pt'),
                             model, is_distributed, category_dict, subcategory_dict, word_dict)

    logging.info('Training finish.')



In [None]:

    # Notebook entry point: configure logging, build and log the config,
    # seed the RNGs, and make sure the checkpoint directory exists.
    import subprocess

    setuplogger()
    args = parse_args()
    dump_args(args)

    for seed_fn in (random.seed, torch.manual_seed):
        seed_fn(args.seed)

    os.makedirs(args.model_dir, exist_ok=True)




INFO:root:args[batch_size]=32


[INFO 2024-12-03 18:48:52,860] args[batch_size]=32


INFO:root:args[category_emb_dim]=100


[INFO 2024-12-03 18:48:52,864] args[category_emb_dim]=100


INFO:root:args[drop_rate]=0.2


[INFO 2024-12-03 18:48:52,865] args[drop_rate]=0.2


INFO:root:args[enable_gpu]=True


[INFO 2024-12-03 18:48:52,870] args[enable_gpu]=True


INFO:root:args[epochs]=5


[INFO 2024-12-03 18:48:52,871] args[epochs]=5


INFO:root:args[filter_num]=3


[INFO 2024-12-03 18:48:52,876] args[filter_num]=3


INFO:root:args[freeze_embedding]=False


[INFO 2024-12-03 18:48:52,877] args[freeze_embedding]=False


INFO:root:args[glove_embedding_path]=/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/glove.840B.300d.txt


[INFO 2024-12-03 18:48:52,879] args[glove_embedding_path]=/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/glove.840B.300d.txt


INFO:root:args[load_ckpt_name]=None


[INFO 2024-12-03 18:48:52,880] args[load_ckpt_name]=None


INFO:root:args[log_steps]=100


[INFO 2024-12-03 18:48:52,885] args[log_steps]=100


INFO:root:args[lr]=0.0003


[INFO 2024-12-03 18:48:52,888] args[lr]=0.0003


INFO:root:args[mode]=train


[INFO 2024-12-03 18:48:52,890] args[mode]=train


INFO:root:args[model_dir]=/content/model


[INFO 2024-12-03 18:48:52,892] args[model_dir]=/content/model


INFO:root:args[nGPU]=1


[INFO 2024-12-03 18:48:52,895] args[nGPU]=1


INFO:root:args[news_dim]=400


[INFO 2024-12-03 18:48:52,897] args[news_dim]=400


INFO:root:args[news_query_vector_dim]=200


[INFO 2024-12-03 18:48:52,900] args[news_query_vector_dim]=200


INFO:root:args[npratio]=4


[INFO 2024-12-03 18:48:52,903] args[npratio]=4


INFO:root:args[num_attention_heads]=20


[INFO 2024-12-03 18:48:52,904] args[num_attention_heads]=20


INFO:root:args[num_words_abstract]=50


[INFO 2024-12-03 18:48:52,906] args[num_words_abstract]=50


INFO:root:args[num_words_title]=20


[INFO 2024-12-03 18:48:52,908] args[num_words_title]=20


INFO:root:args[prepare]=True


[INFO 2024-12-03 18:48:52,912] args[prepare]=True


INFO:root:args[save_steps]=10000


[INFO 2024-12-03 18:48:52,914] args[save_steps]=10000


INFO:root:args[seed]=0


[INFO 2024-12-03 18:48:52,915] args[seed]=0


INFO:root:args[start_epoch]=0


[INFO 2024-12-03 18:48:52,918] args[start_epoch]=0


INFO:root:args[test_data_dir]=/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/MINDsmall_dev


[INFO 2024-12-03 18:48:52,920] args[test_data_dir]=/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/MINDsmall_dev


INFO:root:args[train_data_dir]=/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/MINDsmall_train


[INFO 2024-12-03 18:48:52,924] args[train_data_dir]=/content/drive/MyDrive/Colab Notebooks/NewsRecommendation/data/MINDsmall_train


INFO:root:args[use_abstract]=True


[INFO 2024-12-03 18:48:52,925] args[use_abstract]=True


INFO:root:args[use_category]=True


[INFO 2024-12-03 18:48:52,927] args[use_category]=True


INFO:root:args[use_subcategory]=True


[INFO 2024-12-03 18:48:52,933] args[use_subcategory]=True


INFO:root:args[user_log_length]=50


[INFO 2024-12-03 18:48:52,934] args[user_log_length]=50


INFO:root:args[user_log_mask]=False


[INFO 2024-12-03 18:48:52,936] args[user_log_mask]=False


INFO:root:args[user_query_vector_dim]=200


[INFO 2024-12-03 18:48:52,938] args[user_query_vector_dim]=200


INFO:root:args[word_embedding_dim]=300


[INFO 2024-12-03 18:48:52,939] args[word_embedding_dim]=300


In [None]:
# Create (or locate) the per-GPU training shards and report how many samples/batches they hold.
if 'train' in args.mode:
    if args.prepare:
        logging.info('Preparing training data...')
        total_sample_num = prepare_training_data(args.train_data_dir, args.nGPU, args.npratio, args.seed)
    else:
        total_sample_num = 0
        for i in range(args.nGPU):
            data_file_path = os.path.join(args.train_data_dir, f'behaviors_np{args.npratio}_{i}.tsv')
            if not os.path.exists(data_file_path):
                # Raise instead of calling exit(): exit() is meant for scripts and, in a
                # notebook, can end the kernel session rather than just stopping this cell.
                raise FileNotFoundError(
                    f'Split training data {data_file_path} for GPU {i} does not exist. '
                    'Set args.prepare = True and rerun.')
            # Count lines in Python: the previous unquoted `wc -l` shell call broke on paths
            # containing spaces, such as ".../Colab Notebooks/...".
            with open(data_file_path, 'rb') as f:
                total_sample_num += sum(1 for _ in f)
        logging.info('Skip training data preparation.')
    logging.info(f'{total_sample_num} training samples, {total_sample_num // args.batch_size // args.nGPU} batches in total.')

INFO:root:Preparing training data...


[INFO 2024-12-03 18:48:53,393] Preparing training data...


156965it [00:05, 29066.06it/s]
INFO:root:Writing files...


[INFO 2024-12-03 18:49:00,746] Writing files...


INFO:root:236344 training samples, 7385 batches in total.


[INFO 2024-12-03 18:49:03,584] 236344 training samples, 7385 batches in total.


In [None]:
    # Single-process training on rank 0 (CUDA device 0 when args.enable_gpu is set).
    train(rank=0, args=args)


51282it [00:06, 7424.45it/s]
100%|██████████| 51282/51282 [00:00<00:00, 58961.90it/s]
INFO:root:Initializing word embedding matrix...


[INFO 2024-12-03 18:49:16,050] Initializing word embedding matrix...


INFO:root:Word dict length: 12519


[INFO 2024-12-03 18:51:02,213] Word dict length: 12519


INFO:root:Have words: 11960


[INFO 2024-12-03 18:51:02,219] Have words: 11960


INFO:root:Missing rate: 0.0446521287642783


[INFO 2024-12-03 18:51:02,223] Missing rate: 0.0446521287642783


INFO:root:Training...


[INFO 2024-12-03 18:51:06,494] Training...


INFO:root:[0] Ed: 0, train_loss: inf, acc: inf


[INFO 2024-12-03 18:51:08,853] [0] Ed: 0, train_loss: inf, acc: inf


INFO:root:[0] Ed: 3200, train_loss: 1.59260, acc: 0.33531


[INFO 2024-12-03 18:51:20,312] [0] Ed: 3200, train_loss: 1.59260, acc: 0.33531


INFO:root:[0] Ed: 6400, train_loss: 1.54636, acc: 0.35016


[INFO 2024-12-03 18:51:31,804] [0] Ed: 6400, train_loss: 1.54636, acc: 0.35016


INFO:root:[0] Ed: 9600, train_loss: 1.52596, acc: 0.35688


[INFO 2024-12-03 18:51:43,412] [0] Ed: 9600, train_loss: 1.52596, acc: 0.35688


INFO:root:[0] Ed: 12800, train_loss: 1.50868, acc: 0.36977


[INFO 2024-12-03 18:51:55,119] [0] Ed: 12800, train_loss: 1.50868, acc: 0.36977


INFO:root:[0] Ed: 16000, train_loss: 1.49895, acc: 0.37600


[INFO 2024-12-03 18:52:06,935] [0] Ed: 16000, train_loss: 1.49895, acc: 0.37600


INFO:root:[0] Ed: 19200, train_loss: 1.49260, acc: 0.37932


[INFO 2024-12-03 18:52:18,813] [0] Ed: 19200, train_loss: 1.49260, acc: 0.37932


INFO:root:[0] Ed: 22400, train_loss: 1.48654, acc: 0.38295


[INFO 2024-12-03 18:52:30,772] [0] Ed: 22400, train_loss: 1.48654, acc: 0.38295


INFO:root:[0] Ed: 25600, train_loss: 1.48229, acc: 0.38633


[INFO 2024-12-03 18:52:42,808] [0] Ed: 25600, train_loss: 1.48229, acc: 0.38633


INFO:root:[0] Ed: 28800, train_loss: 1.47743, acc: 0.38979


[INFO 2024-12-03 18:52:54,897] [0] Ed: 28800, train_loss: 1.47743, acc: 0.38979


INFO:root:[0] Ed: 32000, train_loss: 1.47349, acc: 0.39263


[INFO 2024-12-03 18:53:07,052] [0] Ed: 32000, train_loss: 1.47349, acc: 0.39263


INFO:root:[0] Ed: 35200, train_loss: 1.47050, acc: 0.39392


[INFO 2024-12-03 18:53:19,227] [0] Ed: 35200, train_loss: 1.47050, acc: 0.39392


INFO:root:[0] Ed: 38400, train_loss: 1.46714, acc: 0.39674


[INFO 2024-12-03 18:53:31,443] [0] Ed: 38400, train_loss: 1.46714, acc: 0.39674


INFO:root:[0] Ed: 41600, train_loss: 1.46267, acc: 0.40065


[INFO 2024-12-03 18:53:43,691] [0] Ed: 41600, train_loss: 1.46267, acc: 0.40065


INFO:root:[0] Ed: 44800, train_loss: 1.45906, acc: 0.40315


[INFO 2024-12-03 18:53:55,997] [0] Ed: 44800, train_loss: 1.45906, acc: 0.40315


INFO:root:[0] Ed: 48000, train_loss: 1.45688, acc: 0.40413


[INFO 2024-12-03 18:54:08,342] [0] Ed: 48000, train_loss: 1.45688, acc: 0.40413


INFO:root:[0] Ed: 51200, train_loss: 1.45402, acc: 0.40498


[INFO 2024-12-03 18:54:20,724] [0] Ed: 51200, train_loss: 1.45402, acc: 0.40498


INFO:root:[0] Ed: 54400, train_loss: 1.45148, acc: 0.40608


[INFO 2024-12-03 18:54:33,144] [0] Ed: 54400, train_loss: 1.45148, acc: 0.40608


INFO:root:[0] Ed: 57600, train_loss: 1.44961, acc: 0.40745


[INFO 2024-12-03 18:54:45,582] [0] Ed: 57600, train_loss: 1.44961, acc: 0.40745


INFO:root:[0] Ed: 60800, train_loss: 1.44770, acc: 0.40831


[INFO 2024-12-03 18:54:58,026] [0] Ed: 60800, train_loss: 1.44770, acc: 0.40831


INFO:root:[0] Ed: 64000, train_loss: 1.44471, acc: 0.40970


[INFO 2024-12-03 18:55:10,500] [0] Ed: 64000, train_loss: 1.44471, acc: 0.40970


INFO:root:[0] Ed: 67200, train_loss: 1.44242, acc: 0.41158


[INFO 2024-12-03 18:55:22,990] [0] Ed: 67200, train_loss: 1.44242, acc: 0.41158


INFO:root:[0] Ed: 70400, train_loss: 1.43999, acc: 0.41256


[INFO 2024-12-03 18:55:35,485] [0] Ed: 70400, train_loss: 1.43999, acc: 0.41256


INFO:root:[0] Ed: 73600, train_loss: 1.43720, acc: 0.41406


[INFO 2024-12-03 18:55:47,979] [0] Ed: 73600, train_loss: 1.43720, acc: 0.41406


INFO:root:[0] Ed: 76800, train_loss: 1.43478, acc: 0.41544


[INFO 2024-12-03 18:56:00,455] [0] Ed: 76800, train_loss: 1.43478, acc: 0.41544


INFO:root:[0] Ed: 80000, train_loss: 1.43370, acc: 0.41586


[INFO 2024-12-03 18:56:12,951] [0] Ed: 80000, train_loss: 1.43370, acc: 0.41586


INFO:root:[0] Ed: 83200, train_loss: 1.43133, acc: 0.41701


[INFO 2024-12-03 18:56:25,456] [0] Ed: 83200, train_loss: 1.43133, acc: 0.41701


INFO:root:[0] Ed: 86400, train_loss: 1.42953, acc: 0.41794


[INFO 2024-12-03 18:56:37,980] [0] Ed: 86400, train_loss: 1.42953, acc: 0.41794


INFO:root:[0] Ed: 89600, train_loss: 1.42746, acc: 0.41905


[INFO 2024-12-03 18:56:50,493] [0] Ed: 89600, train_loss: 1.42746, acc: 0.41905


INFO:root:[0] Ed: 92800, train_loss: 1.42652, acc: 0.41922


[INFO 2024-12-03 18:57:03,034] [0] Ed: 92800, train_loss: 1.42652, acc: 0.41922


INFO:root:[0] Ed: 96000, train_loss: 1.42530, acc: 0.41976


[INFO 2024-12-03 18:57:15,577] [0] Ed: 96000, train_loss: 1.42530, acc: 0.41976


INFO:root:[0] Ed: 99200, train_loss: 1.42407, acc: 0.42053


[INFO 2024-12-03 18:57:28,123] [0] Ed: 99200, train_loss: 1.42407, acc: 0.42053


INFO:root:[0] Ed: 102400, train_loss: 1.42213, acc: 0.42140


[INFO 2024-12-03 18:57:40,730] [0] Ed: 102400, train_loss: 1.42213, acc: 0.42140


INFO:root:[0] Ed: 105600, train_loss: 1.42028, acc: 0.42249


[INFO 2024-12-03 18:57:53,491] [0] Ed: 105600, train_loss: 1.42028, acc: 0.42249


INFO:root:[0] Ed: 108800, train_loss: 1.41910, acc: 0.42287


[INFO 2024-12-03 18:58:06,411] [0] Ed: 108800, train_loss: 1.41910, acc: 0.42287


INFO:root:[0] Ed: 112000, train_loss: 1.41823, acc: 0.42341


[INFO 2024-12-03 18:58:19,238] [0] Ed: 112000, train_loss: 1.41823, acc: 0.42341


INFO:root:[0] Ed: 115200, train_loss: 1.41716, acc: 0.42406


[INFO 2024-12-03 18:58:31,978] [0] Ed: 115200, train_loss: 1.41716, acc: 0.42406


INFO:root:[0] Ed: 118400, train_loss: 1.41580, acc: 0.42470


[INFO 2024-12-03 18:58:44,661] [0] Ed: 118400, train_loss: 1.41580, acc: 0.42470


INFO:root:[0] Ed: 121600, train_loss: 1.41503, acc: 0.42506


[INFO 2024-12-03 18:58:57,354] [0] Ed: 121600, train_loss: 1.41503, acc: 0.42506


INFO:root:[0] Ed: 124800, train_loss: 1.41369, acc: 0.42561


[INFO 2024-12-03 18:59:10,069] [0] Ed: 124800, train_loss: 1.41369, acc: 0.42561


INFO:root:[0] Ed: 128000, train_loss: 1.41300, acc: 0.42625


[INFO 2024-12-03 18:59:22,805] [0] Ed: 128000, train_loss: 1.41300, acc: 0.42625


INFO:root:[0] Ed: 131200, train_loss: 1.41174, acc: 0.42720


[INFO 2024-12-03 18:59:35,581] [0] Ed: 131200, train_loss: 1.41174, acc: 0.42720


INFO:root:[0] Ed: 134400, train_loss: 1.41082, acc: 0.42767


[INFO 2024-12-03 18:59:48,334] [0] Ed: 134400, train_loss: 1.41082, acc: 0.42767


INFO:root:[0] Ed: 137600, train_loss: 1.41041, acc: 0.42803


[INFO 2024-12-03 19:00:01,112] [0] Ed: 137600, train_loss: 1.41041, acc: 0.42803


INFO:root:[0] Ed: 140800, train_loss: 1.40929, acc: 0.42879


[INFO 2024-12-03 19:00:13,848] [0] Ed: 140800, train_loss: 1.40929, acc: 0.42879


INFO:root:[0] Ed: 144000, train_loss: 1.40835, acc: 0.42875


[INFO 2024-12-03 19:00:26,583] [0] Ed: 144000, train_loss: 1.40835, acc: 0.42875


INFO:root:[0] Ed: 147200, train_loss: 1.40725, acc: 0.42950


[INFO 2024-12-03 19:00:39,304] [0] Ed: 147200, train_loss: 1.40725, acc: 0.42950


INFO:root:[0] Ed: 150400, train_loss: 1.40619, acc: 0.42975


[INFO 2024-12-03 19:00:52,010] [0] Ed: 150400, train_loss: 1.40619, acc: 0.42975


INFO:root:[0] Ed: 153600, train_loss: 1.40498, acc: 0.43016


[INFO 2024-12-03 19:01:04,743] [0] Ed: 153600, train_loss: 1.40498, acc: 0.43016


INFO:root:[0] Ed: 156800, train_loss: 1.40408, acc: 0.43063


[INFO 2024-12-03 19:01:17,445] [0] Ed: 156800, train_loss: 1.40408, acc: 0.43063


INFO:root:[0] Ed: 160000, train_loss: 1.40317, acc: 0.43106


[INFO 2024-12-03 19:01:30,157] [0] Ed: 160000, train_loss: 1.40317, acc: 0.43106


INFO:root:[0] Ed: 163200, train_loss: 1.40240, acc: 0.43156


[INFO 2024-12-03 19:01:42,864] [0] Ed: 163200, train_loss: 1.40240, acc: 0.43156


INFO:root:[0] Ed: 166400, train_loss: 1.40151, acc: 0.43190


[INFO 2024-12-03 19:01:55,573] [0] Ed: 166400, train_loss: 1.40151, acc: 0.43190


INFO:root:[0] Ed: 169600, train_loss: 1.40115, acc: 0.43244


[INFO 2024-12-03 19:02:08,284] [0] Ed: 169600, train_loss: 1.40115, acc: 0.43244


INFO:root:[0] Ed: 172800, train_loss: 1.40047, acc: 0.43277


[INFO 2024-12-03 19:02:21,016] [0] Ed: 172800, train_loss: 1.40047, acc: 0.43277


INFO:root:[0] Ed: 176000, train_loss: 1.39963, acc: 0.43326


[INFO 2024-12-03 19:02:33,739] [0] Ed: 176000, train_loss: 1.39963, acc: 0.43326


INFO:root:[0] Ed: 179200, train_loss: 1.39896, acc: 0.43359


[INFO 2024-12-03 19:02:46,475] [0] Ed: 179200, train_loss: 1.39896, acc: 0.43359


INFO:root:[0] Ed: 182400, train_loss: 1.39830, acc: 0.43388


[INFO 2024-12-03 19:02:59,209] [0] Ed: 182400, train_loss: 1.39830, acc: 0.43388


INFO:root:[0] Ed: 185600, train_loss: 1.39744, acc: 0.43430


[INFO 2024-12-03 19:03:11,950] [0] Ed: 185600, train_loss: 1.39744, acc: 0.43430


INFO:root:[0] Ed: 188800, train_loss: 1.39677, acc: 0.43465


[INFO 2024-12-03 19:03:24,674] [0] Ed: 188800, train_loss: 1.39677, acc: 0.43465


INFO:root:[0] Ed: 192000, train_loss: 1.39645, acc: 0.43482


[INFO 2024-12-03 19:03:37,413] [0] Ed: 192000, train_loss: 1.39645, acc: 0.43482


INFO:root:[0] Ed: 195200, train_loss: 1.39565, acc: 0.43505


[INFO 2024-12-03 19:03:50,155] [0] Ed: 195200, train_loss: 1.39565, acc: 0.43505


INFO:root:[0] Ed: 198400, train_loss: 1.39517, acc: 0.43547


[INFO 2024-12-03 19:04:02,889] [0] Ed: 198400, train_loss: 1.39517, acc: 0.43547


INFO:root:[0] Ed: 201600, train_loss: 1.39457, acc: 0.43580


[INFO 2024-12-03 19:04:15,630] [0] Ed: 201600, train_loss: 1.39457, acc: 0.43580


INFO:root:[0] Ed: 204800, train_loss: 1.39425, acc: 0.43596


[INFO 2024-12-03 19:04:28,360] [0] Ed: 204800, train_loss: 1.39425, acc: 0.43596


INFO:root:[0] Ed: 208000, train_loss: 1.39378, acc: 0.43618


[INFO 2024-12-03 19:04:41,090] [0] Ed: 208000, train_loss: 1.39378, acc: 0.43618


INFO:root:[0] Ed: 211200, train_loss: 1.39328, acc: 0.43636


[INFO 2024-12-03 19:04:53,824] [0] Ed: 211200, train_loss: 1.39328, acc: 0.43636


INFO:root:[0] Ed: 214400, train_loss: 1.39254, acc: 0.43669


[INFO 2024-12-03 19:05:06,571] [0] Ed: 214400, train_loss: 1.39254, acc: 0.43669


INFO:root:[0] Ed: 217600, train_loss: 1.39215, acc: 0.43673


[INFO 2024-12-03 19:05:19,316] [0] Ed: 217600, train_loss: 1.39215, acc: 0.43673


INFO:root:[0] Ed: 220800, train_loss: 1.39149, acc: 0.43693


[INFO 2024-12-03 19:05:32,047] [0] Ed: 220800, train_loss: 1.39149, acc: 0.43693


INFO:root:[0] Ed: 224000, train_loss: 1.39077, acc: 0.43712


[INFO 2024-12-03 19:05:44,780] [0] Ed: 224000, train_loss: 1.39077, acc: 0.43712


INFO:root:[0] Ed: 227200, train_loss: 1.39004, acc: 0.43754


[INFO 2024-12-03 19:05:57,509] [0] Ed: 227200, train_loss: 1.39004, acc: 0.43754


INFO:root:[0] Ed: 230400, train_loss: 1.38935, acc: 0.43784


[INFO 2024-12-03 19:06:10,233] [0] Ed: 230400, train_loss: 1.38935, acc: 0.43784


INFO:root:[0] Ed: 233600, train_loss: 1.38830, acc: 0.43837


[INFO 2024-12-03 19:06:22,962] [0] Ed: 233600, train_loss: 1.38830, acc: 0.43837


INFO:root:Training finish.


[INFO 2024-12-03 19:06:33,541] Training finish.


INFO:root:Model saved to /content/model/epoch-1.pt.


[INFO 2024-12-03 19:06:33,810] Model saved to /content/model/epoch-1.pt.


INFO:root:[0] Ed: 0, train_loss: inf, acc: inf


[INFO 2024-12-03 19:06:33,951] [0] Ed: 0, train_loss: inf, acc: inf


INFO:root:[0] Ed: 3200, train_loss: 1.35951, acc: 0.45719


[INFO 2024-12-03 19:06:46,693] [0] Ed: 3200, train_loss: 1.35951, acc: 0.45719


INFO:root:[0] Ed: 6400, train_loss: 1.34843, acc: 0.45859


[INFO 2024-12-03 19:06:59,421] [0] Ed: 6400, train_loss: 1.34843, acc: 0.45859


INFO:root:[0] Ed: 9600, train_loss: 1.34555, acc: 0.46146


[INFO 2024-12-03 19:07:12,148] [0] Ed: 9600, train_loss: 1.34555, acc: 0.46146


INFO:root:[0] Ed: 12800, train_loss: 1.33953, acc: 0.46453


[INFO 2024-12-03 19:07:24,874] [0] Ed: 12800, train_loss: 1.33953, acc: 0.46453


INFO:root:[0] Ed: 16000, train_loss: 1.34337, acc: 0.46188


[INFO 2024-12-03 19:07:37,623] [0] Ed: 16000, train_loss: 1.34337, acc: 0.46188


INFO:root:[0] Ed: 19200, train_loss: 1.34380, acc: 0.46297


[INFO 2024-12-03 19:07:50,349] [0] Ed: 19200, train_loss: 1.34380, acc: 0.46297


INFO:root:[0] Ed: 22400, train_loss: 1.34181, acc: 0.46384


[INFO 2024-12-03 19:08:03,079] [0] Ed: 22400, train_loss: 1.34181, acc: 0.46384


INFO:root:[0] Ed: 25600, train_loss: 1.34052, acc: 0.46406


[INFO 2024-12-03 19:08:15,816] [0] Ed: 25600, train_loss: 1.34052, acc: 0.46406


INFO:root:[0] Ed: 28800, train_loss: 1.33927, acc: 0.46392


[INFO 2024-12-03 19:08:28,543] [0] Ed: 28800, train_loss: 1.33927, acc: 0.46392


INFO:root:[0] Ed: 32000, train_loss: 1.33915, acc: 0.46406


[INFO 2024-12-03 19:08:41,285] [0] Ed: 32000, train_loss: 1.33915, acc: 0.46406


INFO:root:[0] Ed: 35200, train_loss: 1.34004, acc: 0.46378


[INFO 2024-12-03 19:08:54,029] [0] Ed: 35200, train_loss: 1.34004, acc: 0.46378


INFO:root:[0] Ed: 38400, train_loss: 1.33865, acc: 0.46391


[INFO 2024-12-03 19:09:06,764] [0] Ed: 38400, train_loss: 1.33865, acc: 0.46391


INFO:root:[0] Ed: 41600, train_loss: 1.33623, acc: 0.46582


[INFO 2024-12-03 19:09:19,507] [0] Ed: 41600, train_loss: 1.33623, acc: 0.46582


INFO:root:[0] Ed: 44800, train_loss: 1.33576, acc: 0.46594


[INFO 2024-12-03 19:09:32,243] [0] Ed: 44800, train_loss: 1.33576, acc: 0.46594


INFO:root:[0] Ed: 48000, train_loss: 1.33620, acc: 0.46521


[INFO 2024-12-03 19:09:44,972] [0] Ed: 48000, train_loss: 1.33620, acc: 0.46521


INFO:root:[0] Ed: 51200, train_loss: 1.33522, acc: 0.46551


[INFO 2024-12-03 19:09:57,709] [0] Ed: 51200, train_loss: 1.33522, acc: 0.46551


INFO:root:[0] Ed: 54400, train_loss: 1.33476, acc: 0.46590


[INFO 2024-12-03 19:10:10,432] [0] Ed: 54400, train_loss: 1.33476, acc: 0.46590


INFO:root:[0] Ed: 57600, train_loss: 1.33441, acc: 0.46625


[INFO 2024-12-03 19:10:23,136] [0] Ed: 57600, train_loss: 1.33441, acc: 0.46625


INFO:root:[0] Ed: 60800, train_loss: 1.33439, acc: 0.46622


[INFO 2024-12-03 19:10:35,852] [0] Ed: 60800, train_loss: 1.33439, acc: 0.46622


INFO:root:[0] Ed: 64000, train_loss: 1.33328, acc: 0.46655


[INFO 2024-12-03 19:10:48,607] [0] Ed: 64000, train_loss: 1.33328, acc: 0.46655


INFO:root:[0] Ed: 67200, train_loss: 1.33248, acc: 0.46704


[INFO 2024-12-03 19:11:01,390] [0] Ed: 67200, train_loss: 1.33248, acc: 0.46704


INFO:root:[0] Ed: 70400, train_loss: 1.33151, acc: 0.46764


[INFO 2024-12-03 19:11:14,158] [0] Ed: 70400, train_loss: 1.33151, acc: 0.46764


INFO:root:[0] Ed: 73600, train_loss: 1.33064, acc: 0.46745


[INFO 2024-12-03 19:11:26,882] [0] Ed: 73600, train_loss: 1.33064, acc: 0.46745


INFO:root:[0] Ed: 76800, train_loss: 1.32949, acc: 0.46801


[INFO 2024-12-03 19:11:39,558] [0] Ed: 76800, train_loss: 1.32949, acc: 0.46801


INFO:root:[0] Ed: 80000, train_loss: 1.32976, acc: 0.46776


[INFO 2024-12-03 19:11:52,264] [0] Ed: 80000, train_loss: 1.32976, acc: 0.46776


INFO:root:[0] Ed: 83200, train_loss: 1.32847, acc: 0.46821


[INFO 2024-12-03 19:12:04,981] [0] Ed: 83200, train_loss: 1.32847, acc: 0.46821


INFO:root:[0] Ed: 86400, train_loss: 1.32809, acc: 0.46838


[INFO 2024-12-03 19:12:17,763] [0] Ed: 86400, train_loss: 1.32809, acc: 0.46838


INFO:root:[0] Ed: 89600, train_loss: 1.32719, acc: 0.46868


[INFO 2024-12-03 19:12:30,534] [0] Ed: 89600, train_loss: 1.32719, acc: 0.46868


INFO:root:[0] Ed: 92800, train_loss: 1.32750, acc: 0.46806


[INFO 2024-12-03 19:12:43,291] [0] Ed: 92800, train_loss: 1.32750, acc: 0.46806


INFO:root:[0] Ed: 96000, train_loss: 1.32708, acc: 0.46816


[INFO 2024-12-03 19:12:56,023] [0] Ed: 96000, train_loss: 1.32708, acc: 0.46816


INFO:root:[0] Ed: 99200, train_loss: 1.32721, acc: 0.46796


[INFO 2024-12-03 19:13:08,720] [0] Ed: 99200, train_loss: 1.32721, acc: 0.46796


INFO:root:[0] Ed: 102400, train_loss: 1.32626, acc: 0.46823


[INFO 2024-12-03 19:13:21,445] [0] Ed: 102400, train_loss: 1.32626, acc: 0.46823


INFO:root:[0] Ed: 105600, train_loss: 1.32520, acc: 0.46865


[INFO 2024-12-03 19:13:34,196] [0] Ed: 105600, train_loss: 1.32520, acc: 0.46865


INFO:root:[0] Ed: 108800, train_loss: 1.32533, acc: 0.46876


[INFO 2024-12-03 19:13:46,979] [0] Ed: 108800, train_loss: 1.32533, acc: 0.46876


INFO:root:[0] Ed: 112000, train_loss: 1.32543, acc: 0.46850


[INFO 2024-12-03 19:13:59,734] [0] Ed: 112000, train_loss: 1.32543, acc: 0.46850


INFO:root:[0] Ed: 115200, train_loss: 1.32513, acc: 0.46872


[INFO 2024-12-03 19:14:12,472] [0] Ed: 115200, train_loss: 1.32513, acc: 0.46872


INFO:root:[0] Ed: 118400, train_loss: 1.32450, acc: 0.46895


[INFO 2024-12-03 19:14:25,193] [0] Ed: 118400, train_loss: 1.32450, acc: 0.46895


INFO:root:[0] Ed: 121600, train_loss: 1.32435, acc: 0.46917


[INFO 2024-12-03 19:14:37,924] [0] Ed: 121600, train_loss: 1.32435, acc: 0.46917


INFO:root:[0] Ed: 124800, train_loss: 1.32395, acc: 0.46917


[INFO 2024-12-03 19:14:50,641] [0] Ed: 124800, train_loss: 1.32395, acc: 0.46917


INFO:root:[0] Ed: 128000, train_loss: 1.32414, acc: 0.46913


[INFO 2024-12-03 19:15:03,339] [0] Ed: 128000, train_loss: 1.32414, acc: 0.46913


INFO:root:[0] Ed: 131200, train_loss: 1.32360, acc: 0.46929


[INFO 2024-12-03 19:15:16,060] [0] Ed: 131200, train_loss: 1.32360, acc: 0.46929


INFO:root:[0] Ed: 134400, train_loss: 1.32341, acc: 0.46946


[INFO 2024-12-03 19:15:28,769] [0] Ed: 134400, train_loss: 1.32341, acc: 0.46946


INFO:root:[0] Ed: 137600, train_loss: 1.32378, acc: 0.46945


[INFO 2024-12-03 19:15:41,519] [0] Ed: 137600, train_loss: 1.32378, acc: 0.46945


INFO:root:[0] Ed: 140800, train_loss: 1.32326, acc: 0.46966


[INFO 2024-12-03 19:15:54,257] [0] Ed: 140800, train_loss: 1.32326, acc: 0.46966


INFO:root:[0] Ed: 144000, train_loss: 1.32306, acc: 0.46933


[INFO 2024-12-03 19:16:07,001] [0] Ed: 144000, train_loss: 1.32306, acc: 0.46933


INFO:root:[0] Ed: 147200, train_loss: 1.32275, acc: 0.46950


[INFO 2024-12-03 19:16:19,763] [0] Ed: 147200, train_loss: 1.32275, acc: 0.46950


INFO:root:[0] Ed: 150400, train_loss: 1.32240, acc: 0.46951


[INFO 2024-12-03 19:16:32,527] [0] Ed: 150400, train_loss: 1.32240, acc: 0.46951


INFO:root:[0] Ed: 153600, train_loss: 1.32174, acc: 0.46956


[INFO 2024-12-03 19:16:45,270] [0] Ed: 153600, train_loss: 1.32174, acc: 0.46956


INFO:root:[0] Ed: 156800, train_loss: 1.32140, acc: 0.46965


[INFO 2024-12-03 19:16:58,008] [0] Ed: 156800, train_loss: 1.32140, acc: 0.46965


INFO:root:[0] Ed: 160000, train_loss: 1.32103, acc: 0.46986


[INFO 2024-12-03 19:17:10,735] [0] Ed: 160000, train_loss: 1.32103, acc: 0.46986


INFO:root:[0] Ed: 163200, train_loss: 1.32082, acc: 0.46996


[INFO 2024-12-03 19:17:23,465] [0] Ed: 163200, train_loss: 1.32082, acc: 0.46996


INFO:root:[0] Ed: 166400, train_loss: 1.32064, acc: 0.47002


[INFO 2024-12-03 19:17:36,199] [0] Ed: 166400, train_loss: 1.32064, acc: 0.47002


INFO:root:[0] Ed: 169600, train_loss: 1.32071, acc: 0.47013


[INFO 2024-12-03 19:17:48,930] [0] Ed: 169600, train_loss: 1.32071, acc: 0.47013


INFO:root:[0] Ed: 172800, train_loss: 1.32055, acc: 0.47019


[INFO 2024-12-03 19:18:01,665] [0] Ed: 172800, train_loss: 1.32055, acc: 0.47019


INFO:root:[0] Ed: 176000, train_loss: 1.32026, acc: 0.47028


[INFO 2024-12-03 19:18:14,369] [0] Ed: 176000, train_loss: 1.32026, acc: 0.47028


INFO:root:[0] Ed: 179200, train_loss: 1.32014, acc: 0.47034


[INFO 2024-12-03 19:18:27,095] [0] Ed: 179200, train_loss: 1.32014, acc: 0.47034


INFO:root:[0] Ed: 182400, train_loss: 1.31994, acc: 0.47042


[INFO 2024-12-03 19:18:39,818] [0] Ed: 182400, train_loss: 1.31994, acc: 0.47042


INFO:root:[0] Ed: 185600, train_loss: 1.31945, acc: 0.47055


[INFO 2024-12-03 19:18:52,551] [0] Ed: 185600, train_loss: 1.31945, acc: 0.47055


INFO:root:[0] Ed: 188800, train_loss: 1.31940, acc: 0.47077


[INFO 2024-12-03 19:19:05,263] [0] Ed: 188800, train_loss: 1.31940, acc: 0.47077


INFO:root:[0] Ed: 192000, train_loss: 1.31956, acc: 0.47068


[INFO 2024-12-03 19:19:17,980] [0] Ed: 192000, train_loss: 1.31956, acc: 0.47068


INFO:root:[0] Ed: 195200, train_loss: 1.31929, acc: 0.47070


[INFO 2024-12-03 19:19:30,706] [0] Ed: 195200, train_loss: 1.31929, acc: 0.47070


INFO:root:[0] Ed: 198400, train_loss: 1.31930, acc: 0.47072


[INFO 2024-12-03 19:19:43,434] [0] Ed: 198400, train_loss: 1.31930, acc: 0.47072


INFO:root:[0] Ed: 201600, train_loss: 1.31898, acc: 0.47075


[INFO 2024-12-03 19:19:56,164] [0] Ed: 201600, train_loss: 1.31898, acc: 0.47075


INFO:root:[0] Ed: 204800, train_loss: 1.31925, acc: 0.47061


[INFO 2024-12-03 19:20:08,877] [0] Ed: 204800, train_loss: 1.31925, acc: 0.47061


INFO:root:[0] Ed: 208000, train_loss: 1.31933, acc: 0.47060


[INFO 2024-12-03 19:20:21,593] [0] Ed: 208000, train_loss: 1.31933, acc: 0.47060


INFO:root:[0] Ed: 211200, train_loss: 1.31930, acc: 0.47060


[INFO 2024-12-03 19:20:34,322] [0] Ed: 211200, train_loss: 1.31930, acc: 0.47060


INFO:root:[0] Ed: 214400, train_loss: 1.31906, acc: 0.47082


[INFO 2024-12-03 19:20:47,080] [0] Ed: 214400, train_loss: 1.31906, acc: 0.47082


INFO:root:[0] Ed: 217600, train_loss: 1.31898, acc: 0.47076


[INFO 2024-12-03 19:20:59,813] [0] Ed: 217600, train_loss: 1.31898, acc: 0.47076


INFO:root:[0] Ed: 220800, train_loss: 1.31873, acc: 0.47080


[INFO 2024-12-03 19:21:12,559] [0] Ed: 220800, train_loss: 1.31873, acc: 0.47080


INFO:root:[0] Ed: 224000, train_loss: 1.31826, acc: 0.47098


[INFO 2024-12-03 19:21:25,279] [0] Ed: 224000, train_loss: 1.31826, acc: 0.47098


INFO:root:[0] Ed: 227200, train_loss: 1.31784, acc: 0.47119


[INFO 2024-12-03 19:21:37,984] [0] Ed: 227200, train_loss: 1.31784, acc: 0.47119


INFO:root:[0] Ed: 230400, train_loss: 1.31762, acc: 0.47133


[INFO 2024-12-03 19:21:50,703] [0] Ed: 230400, train_loss: 1.31762, acc: 0.47133


INFO:root:[0] Ed: 233600, train_loss: 1.31696, acc: 0.47166


[INFO 2024-12-03 19:22:03,429] [0] Ed: 233600, train_loss: 1.31696, acc: 0.47166


INFO:root:Training finish.


[INFO 2024-12-03 19:22:14,006] Training finish.


INFO:root:Model saved to /content/model/epoch-2.pt.


[INFO 2024-12-03 19:22:14,258] Model saved to /content/model/epoch-2.pt.


INFO:root:[0] Ed: 0, train_loss: inf, acc: inf


[INFO 2024-12-03 19:22:14,399] [0] Ed: 0, train_loss: inf, acc: inf


INFO:root:[0] Ed: 3200, train_loss: 1.32223, acc: 0.47875


[INFO 2024-12-03 19:22:27,147] [0] Ed: 3200, train_loss: 1.32223, acc: 0.47875


INFO:root:[0] Ed: 6400, train_loss: 1.30358, acc: 0.48594


[INFO 2024-12-03 19:22:39,854] [0] Ed: 6400, train_loss: 1.30358, acc: 0.48594


INFO:root:[0] Ed: 9600, train_loss: 1.30425, acc: 0.48344


[INFO 2024-12-03 19:22:52,578] [0] Ed: 9600, train_loss: 1.30425, acc: 0.48344


INFO:root:[0] Ed: 12800, train_loss: 1.29758, acc: 0.48289


[INFO 2024-12-03 19:23:05,311] [0] Ed: 12800, train_loss: 1.29758, acc: 0.48289


INFO:root:[0] Ed: 16000, train_loss: 1.30110, acc: 0.47969


[INFO 2024-12-03 19:23:18,076] [0] Ed: 16000, train_loss: 1.30110, acc: 0.47969


INFO:root:[0] Ed: 19200, train_loss: 1.30165, acc: 0.48005


[INFO 2024-12-03 19:23:30,837] [0] Ed: 19200, train_loss: 1.30165, acc: 0.48005


INFO:root:[0] Ed: 22400, train_loss: 1.29992, acc: 0.48103


[INFO 2024-12-03 19:23:43,605] [0] Ed: 22400, train_loss: 1.29992, acc: 0.48103


INFO:root:[0] Ed: 25600, train_loss: 1.29884, acc: 0.48219


[INFO 2024-12-03 19:23:56,358] [0] Ed: 25600, train_loss: 1.29884, acc: 0.48219


INFO:root:[0] Ed: 28800, train_loss: 1.29823, acc: 0.48247


[INFO 2024-12-03 19:24:09,107] [0] Ed: 28800, train_loss: 1.29823, acc: 0.48247


INFO:root:[0] Ed: 32000, train_loss: 1.29871, acc: 0.48153


[INFO 2024-12-03 19:24:21,887] [0] Ed: 32000, train_loss: 1.29871, acc: 0.48153


INFO:root:[0] Ed: 35200, train_loss: 1.29950, acc: 0.48099


[INFO 2024-12-03 19:24:34,664] [0] Ed: 35200, train_loss: 1.29950, acc: 0.48099


INFO:root:[0] Ed: 38400, train_loss: 1.29853, acc: 0.48188


[INFO 2024-12-03 19:24:47,435] [0] Ed: 38400, train_loss: 1.29853, acc: 0.48188


INFO:root:[0] Ed: 41600, train_loss: 1.29521, acc: 0.48337


[INFO 2024-12-03 19:25:00,181] [0] Ed: 41600, train_loss: 1.29521, acc: 0.48337


INFO:root:[0] Ed: 44800, train_loss: 1.29483, acc: 0.48348


[INFO 2024-12-03 19:25:12,908] [0] Ed: 44800, train_loss: 1.29483, acc: 0.48348


INFO:root:[0] Ed: 48000, train_loss: 1.29554, acc: 0.48271


[INFO 2024-12-03 19:25:25,634] [0] Ed: 48000, train_loss: 1.29554, acc: 0.48271


INFO:root:[0] Ed: 51200, train_loss: 1.29461, acc: 0.48352


[INFO 2024-12-03 19:25:38,362] [0] Ed: 51200, train_loss: 1.29461, acc: 0.48352


INFO:root:[0] Ed: 54400, train_loss: 1.29389, acc: 0.48397


[INFO 2024-12-03 19:25:51,095] [0] Ed: 54400, train_loss: 1.29389, acc: 0.48397


INFO:root:[0] Ed: 57600, train_loss: 1.29347, acc: 0.48413


[INFO 2024-12-03 19:26:03,831] [0] Ed: 57600, train_loss: 1.29347, acc: 0.48413


INFO:root:[0] Ed: 60800, train_loss: 1.29372, acc: 0.48414


[INFO 2024-12-03 19:26:16,524] [0] Ed: 60800, train_loss: 1.29372, acc: 0.48414


INFO:root:[0] Ed: 64000, train_loss: 1.29278, acc: 0.48480


[INFO 2024-12-03 19:26:29,256] [0] Ed: 64000, train_loss: 1.29278, acc: 0.48480


INFO:root:[0] Ed: 67200, train_loss: 1.29186, acc: 0.48549


[INFO 2024-12-03 19:26:42,036] [0] Ed: 67200, train_loss: 1.29186, acc: 0.48549


INFO:root:[0] Ed: 70400, train_loss: 1.29090, acc: 0.48577


[INFO 2024-12-03 19:26:54,806] [0] Ed: 70400, train_loss: 1.29090, acc: 0.48577


INFO:root:[0] Ed: 73600, train_loss: 1.29041, acc: 0.48565


[INFO 2024-12-03 19:27:07,559] [0] Ed: 73600, train_loss: 1.29041, acc: 0.48565


INFO:root:[0] Ed: 76800, train_loss: 1.28911, acc: 0.48615


[INFO 2024-12-03 19:27:20,295] [0] Ed: 76800, train_loss: 1.28911, acc: 0.48615


INFO:root:[0] Ed: 80000, train_loss: 1.28944, acc: 0.48574


[INFO 2024-12-03 19:27:33,007] [0] Ed: 80000, train_loss: 1.28944, acc: 0.48574


INFO:root:[0] Ed: 83200, train_loss: 1.28838, acc: 0.48631


[INFO 2024-12-03 19:27:45,722] [0] Ed: 83200, train_loss: 1.28838, acc: 0.48631


INFO:root:[0] Ed: 86400, train_loss: 1.28807, acc: 0.48631


[INFO 2024-12-03 19:27:58,453] [0] Ed: 86400, train_loss: 1.28807, acc: 0.48631


INFO:root:[0] Ed: 89600, train_loss: 1.28691, acc: 0.48696


[INFO 2024-12-03 19:28:11,148] [0] Ed: 89600, train_loss: 1.28691, acc: 0.48696


INFO:root:[0] Ed: 92800, train_loss: 1.28756, acc: 0.48626


[INFO 2024-12-03 19:28:23,857] [0] Ed: 92800, train_loss: 1.28756, acc: 0.48626


INFO:root:[0] Ed: 96000, train_loss: 1.28749, acc: 0.48609


[INFO 2024-12-03 19:28:36,566] [0] Ed: 96000, train_loss: 1.28749, acc: 0.48609


INFO:root:[0] Ed: 99200, train_loss: 1.28759, acc: 0.48588


[INFO 2024-12-03 19:28:49,256] [0] Ed: 99200, train_loss: 1.28759, acc: 0.48588


INFO:root:[0] Ed: 102400, train_loss: 1.28659, acc: 0.48602


[INFO 2024-12-03 19:29:02,001] [0] Ed: 102400, train_loss: 1.28659, acc: 0.48602


INFO:root:[0] Ed: 105600, train_loss: 1.28521, acc: 0.48664


[INFO 2024-12-03 19:29:14,757] [0] Ed: 105600, train_loss: 1.28521, acc: 0.48664


INFO:root:[0] Ed: 108800, train_loss: 1.28537, acc: 0.48648


[INFO 2024-12-03 19:29:27,509] [0] Ed: 108800, train_loss: 1.28537, acc: 0.48648


INFO:root:[0] Ed: 112000, train_loss: 1.28570, acc: 0.48636


[INFO 2024-12-03 19:29:40,242] [0] Ed: 112000, train_loss: 1.28570, acc: 0.48636


INFO:root:[0] Ed: 115200, train_loss: 1.28539, acc: 0.48656


[INFO 2024-12-03 19:29:52,973] [0] Ed: 115200, train_loss: 1.28539, acc: 0.48656


INFO:root:[0] Ed: 118400, train_loss: 1.28452, acc: 0.48700


[INFO 2024-12-03 19:30:05,720] [0] Ed: 118400, train_loss: 1.28452, acc: 0.48700


INFO:root:[0] Ed: 121600, train_loss: 1.28434, acc: 0.48702


[INFO 2024-12-03 19:30:18,502] [0] Ed: 121600, train_loss: 1.28434, acc: 0.48702


INFO:root:[0] Ed: 124800, train_loss: 1.28408, acc: 0.48697


[INFO 2024-12-03 19:30:31,262] [0] Ed: 124800, train_loss: 1.28408, acc: 0.48697


INFO:root:[0] Ed: 128000, train_loss: 1.28434, acc: 0.48699


[INFO 2024-12-03 19:30:44,002] [0] Ed: 128000, train_loss: 1.28434, acc: 0.48699


INFO:root:[0] Ed: 131200, train_loss: 1.28367, acc: 0.48732


[INFO 2024-12-03 19:30:56,777] [0] Ed: 131200, train_loss: 1.28367, acc: 0.48732


INFO:root:[0] Ed: 134400, train_loss: 1.28351, acc: 0.48714


[INFO 2024-12-03 19:31:09,510] [0] Ed: 134400, train_loss: 1.28351, acc: 0.48714


INFO:root:[0] Ed: 137600, train_loss: 1.28378, acc: 0.48709


[INFO 2024-12-03 19:31:22,273] [0] Ed: 137600, train_loss: 1.28378, acc: 0.48709


INFO:root:[0] Ed: 140800, train_loss: 1.28348, acc: 0.48722


[INFO 2024-12-03 19:31:35,003] [0] Ed: 140800, train_loss: 1.28348, acc: 0.48722


INFO:root:[0] Ed: 144000, train_loss: 1.28349, acc: 0.48672


[INFO 2024-12-03 19:31:47,736] [0] Ed: 144000, train_loss: 1.28349, acc: 0.48672


INFO:root:[0] Ed: 147200, train_loss: 1.28317, acc: 0.48696


[INFO 2024-12-03 19:32:00,453] [0] Ed: 147200, train_loss: 1.28317, acc: 0.48696


INFO:root:[0] Ed: 150400, train_loss: 1.28269, acc: 0.48713


[INFO 2024-12-03 19:32:13,193] [0] Ed: 150400, train_loss: 1.28269, acc: 0.48713


INFO:root:[0] Ed: 153600, train_loss: 1.28220, acc: 0.48709


[INFO 2024-12-03 19:32:25,933] [0] Ed: 153600, train_loss: 1.28220, acc: 0.48709


INFO:root:[0] Ed: 156800, train_loss: 1.28176, acc: 0.48714


[INFO 2024-12-03 19:32:38,663] [0] Ed: 156800, train_loss: 1.28176, acc: 0.48714


INFO:root:[0] Ed: 160000, train_loss: 1.28139, acc: 0.48736


[INFO 2024-12-03 19:32:51,389] [0] Ed: 160000, train_loss: 1.28139, acc: 0.48736


INFO:root:[0] Ed: 163200, train_loss: 1.28125, acc: 0.48756


[INFO 2024-12-03 19:33:04,116] [0] Ed: 163200, train_loss: 1.28125, acc: 0.48756


INFO:root:[0] Ed: 166400, train_loss: 1.28110, acc: 0.48778


[INFO 2024-12-03 19:33:16,845] [0] Ed: 166400, train_loss: 1.28110, acc: 0.48778


INFO:root:[0] Ed: 169600, train_loss: 1.28118, acc: 0.48788


[INFO 2024-12-03 19:33:29,580] [0] Ed: 169600, train_loss: 1.28118, acc: 0.48788


INFO:root:[0] Ed: 172800, train_loss: 1.28108, acc: 0.48786


[INFO 2024-12-03 19:33:42,316] [0] Ed: 172800, train_loss: 1.28108, acc: 0.48786


INFO:root:[0] Ed: 176000, train_loss: 1.28085, acc: 0.48795


[INFO 2024-12-03 19:33:55,046] [0] Ed: 176000, train_loss: 1.28085, acc: 0.48795


INFO:root:[0] Ed: 179200, train_loss: 1.28066, acc: 0.48813


[INFO 2024-12-03 19:34:07,794] [0] Ed: 179200, train_loss: 1.28066, acc: 0.48813


INFO:root:[0] Ed: 182400, train_loss: 1.28050, acc: 0.48818


[INFO 2024-12-03 19:34:20,511] [0] Ed: 182400, train_loss: 1.28050, acc: 0.48818


INFO:root:[0] Ed: 185600, train_loss: 1.28024, acc: 0.48814


[INFO 2024-12-03 19:34:33,250] [0] Ed: 185600, train_loss: 1.28024, acc: 0.48814


INFO:root:[0] Ed: 188800, train_loss: 1.28021, acc: 0.48831


[INFO 2024-12-03 19:34:45,989] [0] Ed: 188800, train_loss: 1.28021, acc: 0.48831


INFO:root:[0] Ed: 192000, train_loss: 1.28037, acc: 0.48817


[INFO 2024-12-03 19:34:58,720] [0] Ed: 192000, train_loss: 1.28037, acc: 0.48817


INFO:root:[0] Ed: 195200, train_loss: 1.28009, acc: 0.48820


[INFO 2024-12-03 19:35:11,465] [0] Ed: 195200, train_loss: 1.28009, acc: 0.48820


INFO:root:[0] Ed: 198400, train_loss: 1.28013, acc: 0.48838


[INFO 2024-12-03 19:35:24,206] [0] Ed: 198400, train_loss: 1.28013, acc: 0.48838


INFO:root:[0] Ed: 201600, train_loss: 1.27976, acc: 0.48861


[INFO 2024-12-03 19:35:36,930] [0] Ed: 201600, train_loss: 1.27976, acc: 0.48861


INFO:root:[0] Ed: 204800, train_loss: 1.28003, acc: 0.48832


[INFO 2024-12-03 19:35:49,670] [0] Ed: 204800, train_loss: 1.28003, acc: 0.48832


INFO:root:[0] Ed: 208000, train_loss: 1.28016, acc: 0.48838


[INFO 2024-12-03 19:36:02,407] [0] Ed: 208000, train_loss: 1.28016, acc: 0.48838


INFO:root:[0] Ed: 211200, train_loss: 1.28003, acc: 0.48844


[INFO 2024-12-03 19:36:15,148] [0] Ed: 211200, train_loss: 1.28003, acc: 0.48844


INFO:root:[0] Ed: 214400, train_loss: 1.27993, acc: 0.48859


[INFO 2024-12-03 19:36:27,905] [0] Ed: 214400, train_loss: 1.27993, acc: 0.48859


INFO:root:[0] Ed: 217600, train_loss: 1.27992, acc: 0.48845


[INFO 2024-12-03 19:36:40,652] [0] Ed: 217600, train_loss: 1.27992, acc: 0.48845


INFO:root:[0] Ed: 220800, train_loss: 1.27961, acc: 0.48857


[INFO 2024-12-03 19:36:53,374] [0] Ed: 220800, train_loss: 1.27961, acc: 0.48857


INFO:root:[0] Ed: 224000, train_loss: 1.27919, acc: 0.48870


[INFO 2024-12-03 19:37:06,114] [0] Ed: 224000, train_loss: 1.27919, acc: 0.48870


INFO:root:[0] Ed: 227200, train_loss: 1.27880, acc: 0.48895


[INFO 2024-12-03 19:37:18,848] [0] Ed: 227200, train_loss: 1.27880, acc: 0.48895


INFO:root:[0] Ed: 230400, train_loss: 1.27857, acc: 0.48917


[INFO 2024-12-03 19:37:31,575] [0] Ed: 230400, train_loss: 1.27857, acc: 0.48917


INFO:root:[0] Ed: 233600, train_loss: 1.27787, acc: 0.48952


[INFO 2024-12-03 19:37:44,302] [0] Ed: 233600, train_loss: 1.27787, acc: 0.48952


INFO:root:Training finish.


[INFO 2024-12-03 19:37:54,883] Training finish.


INFO:root:Model saved to /content/model/epoch-3.pt.


[INFO 2024-12-03 19:37:55,150] Model saved to /content/model/epoch-3.pt.


INFO:root:[0] Ed: 0, train_loss: inf, acc: inf


[INFO 2024-12-03 19:37:55,299] [0] Ed: 0, train_loss: inf, acc: inf


INFO:root:[0] Ed: 3200, train_loss: 1.28273, acc: 0.50500


[INFO 2024-12-03 19:38:08,032] [0] Ed: 3200, train_loss: 1.28273, acc: 0.50500


INFO:root:[0] Ed: 6400, train_loss: 1.26371, acc: 0.50250


[INFO 2024-12-03 19:38:20,771] [0] Ed: 6400, train_loss: 1.26371, acc: 0.50250


INFO:root:[0] Ed: 9600, train_loss: 1.26487, acc: 0.50167


[INFO 2024-12-03 19:38:33,511] [0] Ed: 9600, train_loss: 1.26487, acc: 0.50167


INFO:root:[0] Ed: 12800, train_loss: 1.25849, acc: 0.50219


[INFO 2024-12-03 19:38:46,249] [0] Ed: 12800, train_loss: 1.25849, acc: 0.50219


INFO:root:[0] Ed: 16000, train_loss: 1.26408, acc: 0.49763


[INFO 2024-12-03 19:38:59,013] [0] Ed: 16000, train_loss: 1.26408, acc: 0.49763


INFO:root:[0] Ed: 19200, train_loss: 1.26555, acc: 0.49755


[INFO 2024-12-03 19:39:11,749] [0] Ed: 19200, train_loss: 1.26555, acc: 0.49755


INFO:root:[0] Ed: 22400, train_loss: 1.26411, acc: 0.49835


[INFO 2024-12-03 19:39:24,485] [0] Ed: 22400, train_loss: 1.26411, acc: 0.49835


INFO:root:[0] Ed: 25600, train_loss: 1.26313, acc: 0.49871


[INFO 2024-12-03 19:39:37,233] [0] Ed: 25600, train_loss: 1.26313, acc: 0.49871


INFO:root:[0] Ed: 28800, train_loss: 1.26227, acc: 0.49878


[INFO 2024-12-03 19:39:49,974] [0] Ed: 28800, train_loss: 1.26227, acc: 0.49878


INFO:root:[0] Ed: 32000, train_loss: 1.26257, acc: 0.49825


[INFO 2024-12-03 19:40:02,739] [0] Ed: 32000, train_loss: 1.26257, acc: 0.49825


INFO:root:[0] Ed: 35200, train_loss: 1.26331, acc: 0.49764


[INFO 2024-12-03 19:40:15,492] [0] Ed: 35200, train_loss: 1.26331, acc: 0.49764


INFO:root:[0] Ed: 38400, train_loss: 1.26242, acc: 0.49799


[INFO 2024-12-03 19:40:28,228] [0] Ed: 38400, train_loss: 1.26242, acc: 0.49799


INFO:root:[0] Ed: 41600, train_loss: 1.25975, acc: 0.49873


[INFO 2024-12-03 19:40:40,966] [0] Ed: 41600, train_loss: 1.25975, acc: 0.49873


INFO:root:[0] Ed: 44800, train_loss: 1.25979, acc: 0.49835


[INFO 2024-12-03 19:40:53,718] [0] Ed: 44800, train_loss: 1.25979, acc: 0.49835


INFO:root:[0] Ed: 48000, train_loss: 1.26039, acc: 0.49783


[INFO 2024-12-03 19:41:06,465] [0] Ed: 48000, train_loss: 1.26039, acc: 0.49783


INFO:root:[0] Ed: 51200, train_loss: 1.25988, acc: 0.49836


[INFO 2024-12-03 19:41:19,218] [0] Ed: 51200, train_loss: 1.25988, acc: 0.49836


INFO:root:[0] Ed: 54400, train_loss: 1.25955, acc: 0.49888


[INFO 2024-12-03 19:41:31,977] [0] Ed: 54400, train_loss: 1.25955, acc: 0.49888


INFO:root:[0] Ed: 57600, train_loss: 1.25865, acc: 0.49984


[INFO 2024-12-03 19:41:44,719] [0] Ed: 57600, train_loss: 1.25865, acc: 0.49984


INFO:root:[0] Ed: 60800, train_loss: 1.25864, acc: 0.49990


[INFO 2024-12-03 19:41:57,448] [0] Ed: 60800, train_loss: 1.25864, acc: 0.49990


INFO:root:[0] Ed: 64000, train_loss: 1.25756, acc: 0.50045


[INFO 2024-12-03 19:42:10,199] [0] Ed: 64000, train_loss: 1.25756, acc: 0.50045


INFO:root:[0] Ed: 67200, train_loss: 1.25647, acc: 0.50132


[INFO 2024-12-03 19:42:22,950] [0] Ed: 67200, train_loss: 1.25647, acc: 0.50132


INFO:root:[0] Ed: 70400, train_loss: 1.25598, acc: 0.50121


[INFO 2024-12-03 19:42:35,698] [0] Ed: 70400, train_loss: 1.25598, acc: 0.50121


INFO:root:[0] Ed: 73600, train_loss: 1.25574, acc: 0.50103


[INFO 2024-12-03 19:42:48,440] [0] Ed: 73600, train_loss: 1.25574, acc: 0.50103


INFO:root:[0] Ed: 76800, train_loss: 1.25455, acc: 0.50150


[INFO 2024-12-03 19:43:01,154] [0] Ed: 76800, train_loss: 1.25455, acc: 0.50150


INFO:root:[0] Ed: 80000, train_loss: 1.25509, acc: 0.50126


[INFO 2024-12-03 19:43:13,879] [0] Ed: 80000, train_loss: 1.25509, acc: 0.50126


INFO:root:[0] Ed: 83200, train_loss: 1.25401, acc: 0.50167


[INFO 2024-12-03 19:43:26,621] [0] Ed: 83200, train_loss: 1.25401, acc: 0.50167


INFO:root:[0] Ed: 86400, train_loss: 1.25370, acc: 0.50171


[INFO 2024-12-03 19:43:39,374] [0] Ed: 86400, train_loss: 1.25370, acc: 0.50171


INFO:root:[0] Ed: 89600, train_loss: 1.25270, acc: 0.50205


[INFO 2024-12-03 19:43:52,143] [0] Ed: 89600, train_loss: 1.25270, acc: 0.50205


INFO:root:[0] Ed: 92800, train_loss: 1.25366, acc: 0.50134


[INFO 2024-12-03 19:44:04,891] [0] Ed: 92800, train_loss: 1.25366, acc: 0.50134


INFO:root:[0] Ed: 96000, train_loss: 1.25325, acc: 0.50121


[INFO 2024-12-03 19:44:17,646] [0] Ed: 96000, train_loss: 1.25325, acc: 0.50121


INFO:root:[0] Ed: 99200, train_loss: 1.25307, acc: 0.50110


[INFO 2024-12-03 19:44:30,382] [0] Ed: 99200, train_loss: 1.25307, acc: 0.50110


INFO:root:[0] Ed: 102400, train_loss: 1.25207, acc: 0.50135


[INFO 2024-12-03 19:44:43,136] [0] Ed: 102400, train_loss: 1.25207, acc: 0.50135


INFO:root:[0] Ed: 105600, train_loss: 1.25066, acc: 0.50200


[INFO 2024-12-03 19:44:55,900] [0] Ed: 105600, train_loss: 1.25066, acc: 0.50200


INFO:root:[0] Ed: 108800, train_loss: 1.25063, acc: 0.50195


[INFO 2024-12-03 19:45:08,657] [0] Ed: 108800, train_loss: 1.25063, acc: 0.50195


INFO:root:[0] Ed: 112000, train_loss: 1.25105, acc: 0.50177


[INFO 2024-12-03 19:45:21,381] [0] Ed: 112000, train_loss: 1.25105, acc: 0.50177


INFO:root:[0] Ed: 115200, train_loss: 1.25088, acc: 0.50189


[INFO 2024-12-03 19:45:34,107] [0] Ed: 115200, train_loss: 1.25088, acc: 0.50189


INFO:root:[0] Ed: 118400, train_loss: 1.24992, acc: 0.50241


[INFO 2024-12-03 19:45:46,839] [0] Ed: 118400, train_loss: 1.24992, acc: 0.50241


INFO:root:[0] Ed: 121600, train_loss: 1.24976, acc: 0.50246


[INFO 2024-12-03 19:45:59,608] [0] Ed: 121600, train_loss: 1.24976, acc: 0.50246


INFO:root:[0] Ed: 124800, train_loss: 1.24942, acc: 0.50252


[INFO 2024-12-03 19:46:12,370] [0] Ed: 124800, train_loss: 1.24942, acc: 0.50252


INFO:root:[0] Ed: 128000, train_loss: 1.24968, acc: 0.50239


[INFO 2024-12-03 19:46:25,097] [0] Ed: 128000, train_loss: 1.24968, acc: 0.50239


INFO:root:[0] Ed: 131200, train_loss: 1.24924, acc: 0.50263


[INFO 2024-12-03 19:46:37,848] [0] Ed: 131200, train_loss: 1.24924, acc: 0.50263


INFO:root:[0] Ed: 134400, train_loss: 1.24909, acc: 0.50271


[INFO 2024-12-03 19:46:50,592] [0] Ed: 134400, train_loss: 1.24909, acc: 0.50271


INFO:root:[0] Ed: 137600, train_loss: 1.24936, acc: 0.50261


[INFO 2024-12-03 19:47:03,368] [0] Ed: 137600, train_loss: 1.24936, acc: 0.50261


INFO:root:[0] Ed: 140800, train_loss: 1.24898, acc: 0.50279


[INFO 2024-12-03 19:47:16,108] [0] Ed: 140800, train_loss: 1.24898, acc: 0.50279


INFO:root:[0] Ed: 144000, train_loss: 1.24899, acc: 0.50248


[INFO 2024-12-03 19:47:28,848] [0] Ed: 144000, train_loss: 1.24899, acc: 0.50248


INFO:root:[0] Ed: 147200, train_loss: 1.24877, acc: 0.50259


[INFO 2024-12-03 19:47:41,588] [0] Ed: 147200, train_loss: 1.24877, acc: 0.50259


INFO:root:[0] Ed: 150400, train_loss: 1.24837, acc: 0.50275


[INFO 2024-12-03 19:47:54,321] [0] Ed: 150400, train_loss: 1.24837, acc: 0.50275


INFO:root:[0] Ed: 153600, train_loss: 1.24784, acc: 0.50286


[INFO 2024-12-03 19:48:07,076] [0] Ed: 153600, train_loss: 1.24784, acc: 0.50286


INFO:root:[0] Ed: 156800, train_loss: 1.24750, acc: 0.50305


[INFO 2024-12-03 19:48:19,831] [0] Ed: 156800, train_loss: 1.24750, acc: 0.50305


INFO:root:[0] Ed: 160000, train_loss: 1.24714, acc: 0.50320


[INFO 2024-12-03 19:48:32,587] [0] Ed: 160000, train_loss: 1.24714, acc: 0.50320


INFO:root:[0] Ed: 163200, train_loss: 1.24705, acc: 0.50317


[INFO 2024-12-03 19:48:45,334] [0] Ed: 163200, train_loss: 1.24705, acc: 0.50317


INFO:root:[0] Ed: 166400, train_loss: 1.24699, acc: 0.50314


[INFO 2024-12-03 19:48:58,074] [0] Ed: 166400, train_loss: 1.24699, acc: 0.50314


INFO:root:[0] Ed: 169600, train_loss: 1.24699, acc: 0.50321


[INFO 2024-12-03 19:49:10,806] [0] Ed: 169600, train_loss: 1.24699, acc: 0.50321


INFO:root:[0] Ed: 172800, train_loss: 1.24674, acc: 0.50321


[INFO 2024-12-03 19:49:23,562] [0] Ed: 172800, train_loss: 1.24674, acc: 0.50321


INFO:root:[0] Ed: 176000, train_loss: 1.24656, acc: 0.50330


[INFO 2024-12-03 19:49:36,310] [0] Ed: 176000, train_loss: 1.24656, acc: 0.50330


INFO:root:[0] Ed: 179200, train_loss: 1.24648, acc: 0.50340


[INFO 2024-12-03 19:49:49,073] [0] Ed: 179200, train_loss: 1.24648, acc: 0.50340


INFO:root:[0] Ed: 182400, train_loss: 1.24636, acc: 0.50336


[INFO 2024-12-03 19:50:01,819] [0] Ed: 182400, train_loss: 1.24636, acc: 0.50336


INFO:root:[0] Ed: 185600, train_loss: 1.24603, acc: 0.50344


[INFO 2024-12-03 19:50:14,566] [0] Ed: 185600, train_loss: 1.24603, acc: 0.50344


INFO:root:[0] Ed: 188800, train_loss: 1.24602, acc: 0.50356


[INFO 2024-12-03 19:50:27,312] [0] Ed: 188800, train_loss: 1.24602, acc: 0.50356


INFO:root:[0] Ed: 192000, train_loss: 1.24607, acc: 0.50354


[INFO 2024-12-03 19:50:40,060] [0] Ed: 192000, train_loss: 1.24607, acc: 0.50354


INFO:root:[0] Ed: 195200, train_loss: 1.24591, acc: 0.50366


[INFO 2024-12-03 19:50:52,820] [0] Ed: 195200, train_loss: 1.24591, acc: 0.50366


INFO:root:[0] Ed: 198400, train_loss: 1.24594, acc: 0.50376


[INFO 2024-12-03 19:51:05,588] [0] Ed: 198400, train_loss: 1.24594, acc: 0.50376


INFO:root:[0] Ed: 201600, train_loss: 1.24560, acc: 0.50400


[INFO 2024-12-03 19:51:18,334] [0] Ed: 201600, train_loss: 1.24560, acc: 0.50400


INFO:root:[0] Ed: 204800, train_loss: 1.24563, acc: 0.50388


[INFO 2024-12-03 19:51:31,061] [0] Ed: 204800, train_loss: 1.24563, acc: 0.50388


INFO:root:[0] Ed: 208000, train_loss: 1.24589, acc: 0.50386


[INFO 2024-12-03 19:51:43,818] [0] Ed: 208000, train_loss: 1.24589, acc: 0.50386


INFO:root:[0] Ed: 211200, train_loss: 1.24586, acc: 0.50385


[INFO 2024-12-03 19:51:56,568] [0] Ed: 211200, train_loss: 1.24586, acc: 0.50385


INFO:root:[0] Ed: 214400, train_loss: 1.24579, acc: 0.50395


[INFO 2024-12-03 19:52:09,328] [0] Ed: 214400, train_loss: 1.24579, acc: 0.50395


INFO:root:[0] Ed: 217600, train_loss: 1.24568, acc: 0.50398


[INFO 2024-12-03 19:52:22,078] [0] Ed: 217600, train_loss: 1.24568, acc: 0.50398


INFO:root:[0] Ed: 220800, train_loss: 1.24545, acc: 0.50404


[INFO 2024-12-03 19:52:34,821] [0] Ed: 220800, train_loss: 1.24545, acc: 0.50404


INFO:root:[0] Ed: 224000, train_loss: 1.24508, acc: 0.50416


[INFO 2024-12-03 19:52:47,561] [0] Ed: 224000, train_loss: 1.24508, acc: 0.50416


INFO:root:[0] Ed: 227200, train_loss: 1.24459, acc: 0.50441


[INFO 2024-12-03 19:53:00,313] [0] Ed: 227200, train_loss: 1.24459, acc: 0.50441


INFO:root:[0] Ed: 230400, train_loss: 1.24437, acc: 0.50449


[INFO 2024-12-03 19:53:13,046] [0] Ed: 230400, train_loss: 1.24437, acc: 0.50449


INFO:root:[0] Ed: 233600, train_loss: 1.24369, acc: 0.50484


[INFO 2024-12-03 19:53:25,777] [0] Ed: 233600, train_loss: 1.24369, acc: 0.50484


INFO:root:Training finish.


[INFO 2024-12-03 19:53:36,372] Training finish.


INFO:root:Model saved to /content/model/epoch-4.pt.


[INFO 2024-12-03 19:53:36,625] Model saved to /content/model/epoch-4.pt.


INFO:root:[0] Ed: 0, train_loss: inf, acc: inf


[INFO 2024-12-03 19:53:36,769] [0] Ed: 0, train_loss: inf, acc: inf


INFO:root:[0] Ed: 3200, train_loss: 1.24824, acc: 0.51906


[INFO 2024-12-03 19:53:49,526] [0] Ed: 3200, train_loss: 1.24824, acc: 0.51906


INFO:root:[0] Ed: 6400, train_loss: 1.23061, acc: 0.51719


[INFO 2024-12-03 19:54:02,259] [0] Ed: 6400, train_loss: 1.23061, acc: 0.51719


INFO:root:[0] Ed: 9600, train_loss: 1.23503, acc: 0.51531


[INFO 2024-12-03 19:54:15,007] [0] Ed: 9600, train_loss: 1.23503, acc: 0.51531


INFO:root:[0] Ed: 12800, train_loss: 1.22861, acc: 0.51617


[INFO 2024-12-03 19:54:27,756] [0] Ed: 12800, train_loss: 1.22861, acc: 0.51617


INFO:root:[0] Ed: 16000, train_loss: 1.23431, acc: 0.51313


[INFO 2024-12-03 19:54:40,536] [0] Ed: 16000, train_loss: 1.23431, acc: 0.51313


INFO:root:[0] Ed: 19200, train_loss: 1.23455, acc: 0.51146


[INFO 2024-12-03 19:54:53,274] [0] Ed: 19200, train_loss: 1.23455, acc: 0.51146


INFO:root:[0] Ed: 22400, train_loss: 1.23301, acc: 0.51219


[INFO 2024-12-03 19:55:06,015] [0] Ed: 22400, train_loss: 1.23301, acc: 0.51219


INFO:root:[0] Ed: 25600, train_loss: 1.23146, acc: 0.51184


[INFO 2024-12-03 19:55:18,747] [0] Ed: 25600, train_loss: 1.23146, acc: 0.51184


INFO:root:[0] Ed: 28800, train_loss: 1.23008, acc: 0.51215


[INFO 2024-12-03 19:55:31,498] [0] Ed: 28800, train_loss: 1.23008, acc: 0.51215


INFO:root:[0] Ed: 32000, train_loss: 1.23091, acc: 0.51116


[INFO 2024-12-03 19:55:44,274] [0] Ed: 32000, train_loss: 1.23091, acc: 0.51116


INFO:root:[0] Ed: 35200, train_loss: 1.23206, acc: 0.51065


[INFO 2024-12-03 19:55:57,025] [0] Ed: 35200, train_loss: 1.23206, acc: 0.51065


INFO:root:[0] Ed: 38400, train_loss: 1.23079, acc: 0.51115


[INFO 2024-12-03 19:56:09,781] [0] Ed: 38400, train_loss: 1.23079, acc: 0.51115


INFO:root:[0] Ed: 41600, train_loss: 1.22817, acc: 0.51226


[INFO 2024-12-03 19:56:22,509] [0] Ed: 41600, train_loss: 1.22817, acc: 0.51226


INFO:root:[0] Ed: 44800, train_loss: 1.22767, acc: 0.51217


[INFO 2024-12-03 19:56:35,249] [0] Ed: 44800, train_loss: 1.22767, acc: 0.51217


INFO:root:[0] Ed: 48000, train_loss: 1.22828, acc: 0.51165


[INFO 2024-12-03 19:56:47,983] [0] Ed: 48000, train_loss: 1.22828, acc: 0.51165


INFO:root:[0] Ed: 51200, train_loss: 1.22747, acc: 0.51160


[INFO 2024-12-03 19:57:00,727] [0] Ed: 51200, train_loss: 1.22747, acc: 0.51160


INFO:root:[0] Ed: 54400, train_loss: 1.22669, acc: 0.51210


[INFO 2024-12-03 19:57:13,491] [0] Ed: 54400, train_loss: 1.22669, acc: 0.51210


INFO:root:[0] Ed: 57600, train_loss: 1.22602, acc: 0.51262


[INFO 2024-12-03 19:57:26,258] [0] Ed: 57600, train_loss: 1.22602, acc: 0.51262


INFO:root:[0] Ed: 60800, train_loss: 1.22583, acc: 0.51278


[INFO 2024-12-03 19:57:38,986] [0] Ed: 60800, train_loss: 1.22583, acc: 0.51278


INFO:root:[0] Ed: 64000, train_loss: 1.22472, acc: 0.51333


[INFO 2024-12-03 19:57:51,732] [0] Ed: 64000, train_loss: 1.22472, acc: 0.51333


INFO:root:[0] Ed: 67200, train_loss: 1.22347, acc: 0.51391


[INFO 2024-12-03 19:58:04,480] [0] Ed: 67200, train_loss: 1.22347, acc: 0.51391


INFO:root:[0] Ed: 70400, train_loss: 1.22289, acc: 0.51389


[INFO 2024-12-03 19:58:17,203] [0] Ed: 70400, train_loss: 1.22289, acc: 0.51389


INFO:root:[0] Ed: 73600, train_loss: 1.22291, acc: 0.51333


[INFO 2024-12-03 19:58:29,960] [0] Ed: 73600, train_loss: 1.22291, acc: 0.51333


INFO:root:[0] Ed: 76800, train_loss: 1.22206, acc: 0.51387


[INFO 2024-12-03 19:58:42,677] [0] Ed: 76800, train_loss: 1.22206, acc: 0.51387


INFO:root:[0] Ed: 80000, train_loss: 1.22238, acc: 0.51411


[INFO 2024-12-03 19:58:55,406] [0] Ed: 80000, train_loss: 1.22238, acc: 0.51411


INFO:root:[0] Ed: 83200, train_loss: 1.22143, acc: 0.51486


[INFO 2024-12-03 19:59:08,134] [0] Ed: 83200, train_loss: 1.22143, acc: 0.51486


INFO:root:[0] Ed: 86400, train_loss: 1.22104, acc: 0.51524


[INFO 2024-12-03 19:59:20,872] [0] Ed: 86400, train_loss: 1.22104, acc: 0.51524


INFO:root:[0] Ed: 89600, train_loss: 1.21996, acc: 0.51569


[INFO 2024-12-03 19:59:33,615] [0] Ed: 89600, train_loss: 1.21996, acc: 0.51569


INFO:root:[0] Ed: 92800, train_loss: 1.22073, acc: 0.51513


[INFO 2024-12-03 19:59:46,375] [0] Ed: 92800, train_loss: 1.22073, acc: 0.51513


INFO:root:[0] Ed: 96000, train_loss: 1.22039, acc: 0.51518


[INFO 2024-12-03 19:59:59,111] [0] Ed: 96000, train_loss: 1.22039, acc: 0.51518


INFO:root:[0] Ed: 99200, train_loss: 1.22029, acc: 0.51500


[INFO 2024-12-03 20:00:11,851] [0] Ed: 99200, train_loss: 1.22029, acc: 0.51500


INFO:root:[0] Ed: 102400, train_loss: 1.21937, acc: 0.51530


[INFO 2024-12-03 20:00:24,589] [0] Ed: 102400, train_loss: 1.21937, acc: 0.51530


INFO:root:[0] Ed: 105600, train_loss: 1.21788, acc: 0.51598


[INFO 2024-12-03 20:00:37,348] [0] Ed: 105600, train_loss: 1.21788, acc: 0.51598


INFO:root:[0] Ed: 108800, train_loss: 1.21788, acc: 0.51587


[INFO 2024-12-03 20:00:50,086] [0] Ed: 108800, train_loss: 1.21788, acc: 0.51587


INFO:root:[0] Ed: 112000, train_loss: 1.21830, acc: 0.51559


[INFO 2024-12-03 20:01:02,822] [0] Ed: 112000, train_loss: 1.21830, acc: 0.51559


INFO:root:[0] Ed: 115200, train_loss: 1.21808, acc: 0.51589


[INFO 2024-12-03 20:01:15,537] [0] Ed: 115200, train_loss: 1.21808, acc: 0.51589


INFO:root:[0] Ed: 118400, train_loss: 1.21692, acc: 0.51637


[INFO 2024-12-03 20:01:28,274] [0] Ed: 118400, train_loss: 1.21692, acc: 0.51637


INFO:root:[0] Ed: 121600, train_loss: 1.21678, acc: 0.51650


[INFO 2024-12-03 20:01:41,045] [0] Ed: 121600, train_loss: 1.21678, acc: 0.51650


INFO:root:[0] Ed: 124800, train_loss: 1.21652, acc: 0.51655


[INFO 2024-12-03 20:01:53,786] [0] Ed: 124800, train_loss: 1.21652, acc: 0.51655


INFO:root:[0] Ed: 128000, train_loss: 1.21681, acc: 0.51646


[INFO 2024-12-03 20:02:06,530] [0] Ed: 128000, train_loss: 1.21681, acc: 0.51646


INFO:root:[0] Ed: 131200, train_loss: 1.21637, acc: 0.51662


[INFO 2024-12-03 20:02:19,278] [0] Ed: 131200, train_loss: 1.21637, acc: 0.51662


INFO:root:[0] Ed: 134400, train_loss: 1.21624, acc: 0.51653


[INFO 2024-12-03 20:02:32,002] [0] Ed: 134400, train_loss: 1.21624, acc: 0.51653


INFO:root:[0] Ed: 137600, train_loss: 1.21642, acc: 0.51651


[INFO 2024-12-03 20:02:44,760] [0] Ed: 137600, train_loss: 1.21642, acc: 0.51651


INFO:root:[0] Ed: 140800, train_loss: 1.21606, acc: 0.51668


[INFO 2024-12-03 20:02:57,499] [0] Ed: 140800, train_loss: 1.21606, acc: 0.51668


INFO:root:[0] Ed: 144000, train_loss: 1.21610, acc: 0.51626


[INFO 2024-12-03 20:03:10,249] [0] Ed: 144000, train_loss: 1.21610, acc: 0.51626


INFO:root:[0] Ed: 147200, train_loss: 1.21591, acc: 0.51648


[INFO 2024-12-03 20:03:22,987] [0] Ed: 147200, train_loss: 1.21591, acc: 0.51648


INFO:root:[0] Ed: 150400, train_loss: 1.21525, acc: 0.51680


[INFO 2024-12-03 20:03:35,718] [0] Ed: 150400, train_loss: 1.21525, acc: 0.51680


INFO:root:[0] Ed: 153600, train_loss: 1.21493, acc: 0.51686


[INFO 2024-12-03 20:03:48,453] [0] Ed: 153600, train_loss: 1.21493, acc: 0.51686


INFO:root:[0] Ed: 156800, train_loss: 1.21451, acc: 0.51716


[INFO 2024-12-03 20:04:01,198] [0] Ed: 156800, train_loss: 1.21451, acc: 0.51716


INFO:root:[0] Ed: 160000, train_loss: 1.21408, acc: 0.51736


[INFO 2024-12-03 20:04:13,940] [0] Ed: 160000, train_loss: 1.21408, acc: 0.51736


INFO:root:[0] Ed: 163200, train_loss: 1.21390, acc: 0.51741


[INFO 2024-12-03 20:04:26,674] [0] Ed: 163200, train_loss: 1.21390, acc: 0.51741


INFO:root:[0] Ed: 166400, train_loss: 1.21388, acc: 0.51733


[INFO 2024-12-03 20:04:39,406] [0] Ed: 166400, train_loss: 1.21388, acc: 0.51733


INFO:root:[0] Ed: 169600, train_loss: 1.21406, acc: 0.51723


[INFO 2024-12-03 20:04:52,145] [0] Ed: 169600, train_loss: 1.21406, acc: 0.51723


INFO:root:[0] Ed: 172800, train_loss: 1.21388, acc: 0.51722


[INFO 2024-12-03 20:05:04,889] [0] Ed: 172800, train_loss: 1.21388, acc: 0.51722


INFO:root:[0] Ed: 176000, train_loss: 1.21369, acc: 0.51734


[INFO 2024-12-03 20:05:17,622] [0] Ed: 176000, train_loss: 1.21369, acc: 0.51734


INFO:root:[0] Ed: 179200, train_loss: 1.21358, acc: 0.51751


[INFO 2024-12-03 20:05:30,360] [0] Ed: 179200, train_loss: 1.21358, acc: 0.51751


INFO:root:[0] Ed: 182400, train_loss: 1.21342, acc: 0.51751


[INFO 2024-12-03 20:05:43,089] [0] Ed: 182400, train_loss: 1.21342, acc: 0.51751


INFO:root:[0] Ed: 185600, train_loss: 1.21311, acc: 0.51758


[INFO 2024-12-03 20:05:55,853] [0] Ed: 185600, train_loss: 1.21311, acc: 0.51758


INFO:root:[0] Ed: 188800, train_loss: 1.21313, acc: 0.51776


[INFO 2024-12-03 20:06:08,587] [0] Ed: 188800, train_loss: 1.21313, acc: 0.51776


INFO:root:[0] Ed: 192000, train_loss: 1.21328, acc: 0.51774


[INFO 2024-12-03 20:06:21,319] [0] Ed: 192000, train_loss: 1.21328, acc: 0.51774


INFO:root:[0] Ed: 195200, train_loss: 1.21314, acc: 0.51794


[INFO 2024-12-03 20:06:34,070] [0] Ed: 195200, train_loss: 1.21314, acc: 0.51794


INFO:root:[0] Ed: 198400, train_loss: 1.21337, acc: 0.51789


[INFO 2024-12-03 20:06:46,814] [0] Ed: 198400, train_loss: 1.21337, acc: 0.51789


INFO:root:[0] Ed: 201600, train_loss: 1.21300, acc: 0.51804


[INFO 2024-12-03 20:06:59,568] [0] Ed: 201600, train_loss: 1.21300, acc: 0.51804


INFO:root:[0] Ed: 204800, train_loss: 1.21313, acc: 0.51797


[INFO 2024-12-03 20:07:12,304] [0] Ed: 204800, train_loss: 1.21313, acc: 0.51797


INFO:root:[0] Ed: 208000, train_loss: 1.21338, acc: 0.51793


[INFO 2024-12-03 20:07:25,041] [0] Ed: 208000, train_loss: 1.21338, acc: 0.51793


INFO:root:[0] Ed: 211200, train_loss: 1.21332, acc: 0.51799


[INFO 2024-12-03 20:07:37,794] [0] Ed: 211200, train_loss: 1.21332, acc: 0.51799


INFO:root:[0] Ed: 214400, train_loss: 1.21325, acc: 0.51798


[INFO 2024-12-03 20:07:50,562] [0] Ed: 214400, train_loss: 1.21325, acc: 0.51798


INFO:root:[0] Ed: 217600, train_loss: 1.21307, acc: 0.51814


[INFO 2024-12-03 20:08:03,308] [0] Ed: 217600, train_loss: 1.21307, acc: 0.51814


INFO:root:[0] Ed: 220800, train_loss: 1.21283, acc: 0.51823


[INFO 2024-12-03 20:08:16,053] [0] Ed: 220800, train_loss: 1.21283, acc: 0.51823


INFO:root:[0] Ed: 224000, train_loss: 1.21253, acc: 0.51833


[INFO 2024-12-03 20:08:28,787] [0] Ed: 224000, train_loss: 1.21253, acc: 0.51833


INFO:root:[0] Ed: 227200, train_loss: 1.21213, acc: 0.51849


[INFO 2024-12-03 20:08:41,521] [0] Ed: 227200, train_loss: 1.21213, acc: 0.51849


INFO:root:[0] Ed: 230400, train_loss: 1.21182, acc: 0.51859


[INFO 2024-12-03 20:08:54,255] [0] Ed: 230400, train_loss: 1.21182, acc: 0.51859


INFO:root:[0] Ed: 233600, train_loss: 1.21125, acc: 0.51888


[INFO 2024-12-03 20:09:07,004] [0] Ed: 233600, train_loss: 1.21125, acc: 0.51888


INFO:root:Training finish.


[INFO 2024-12-03 20:09:17,590] Training finish.


INFO:root:Model saved to /content/model/epoch-5.pt.


[INFO 2024-12-03 20:09:17,845] Model saved to /content/model/epoch-5.pt.


In [None]:
def _collate_test_batch(tuple_list):
    """Collate DatasetTest items into (log_vecs, log_mask, news_vecs, labels).

    The first two fields are stacked into float32 tensors. Stacking goes through
    one numpy array, because building a tensor from a list of ndarrays is very slow.
    The candidate news vectors and labels stay as per-impression lists, since
    impressions can have different numbers of candidates.
    """
    log_vecs = torch.from_numpy(np.asarray([x[0] for x in tuple_list], dtype=np.float32))
    log_mask = torch.from_numpy(np.asarray([x[1] for x in tuple_list], dtype=np.float32))
    news_vecs = [x[2] for x in tuple_list]
    labels = [x[3] for x in tuple_list]
    return log_vecs, log_mask, news_vecs, labels


def _mean_pairwise_cosine(vectors, num_samples=1000000):
    """Estimate the mean cosine similarity between randomly drawn pairs of rows.

    Indices are drawn with `random.randrange(1, len(vectors))`, so row 0 is never
    sampled. Row 0 is presumably the padding entry (TODO confirm). Draws with
    i == j are discarded, and the sum is divided by the number of pairs actually
    kept.

    Args:
        vectors: 2-D array, one encoded news vector per row.
        num_samples: number of (i, j) draws.

    Returns:
        The mean cosine similarity, or nan when fewer than two rows can be sampled.
    """
    num_rows = len(vectors)
    if num_rows < 3:
        return float('nan')
    # Normalise the sampleable rows once instead of recomputing two norms per draw.
    candidates = np.asarray(vectors[1:], dtype=np.float64)
    unit = candidates / np.linalg.norm(candidates, axis=1, keepdims=True)
    total, kept = 0.0, 0
    for _ in tqdm(range(num_samples)):
        i = random.randrange(1, num_rows)
        j = random.randrange(1, num_rows)
        if i != j:
            total += float(np.dot(unit[i - 1], unit[j - 1]))
            kept += 1
    return total / kept if kept else float('nan')


def test(rank, args):
    """Evaluate a saved checkpoint on the test impressions of shard `rank`.

    The function:
    - loads the model and the vocabularies it was trained with from
      `args.model_dir/args.load_ckpt_name`;
    - encodes every news item in `args.test_data_dir/news.tsv` (optionally with
      custom abstracts);
    - logs a random-pair cosine "doc-sim" diagnostic on rank 0;
    - scores each impression in `behaviors_{rank}.tsv` by the dot product of the
      user and candidate vectors, and logs AUC / MRR / nDCG@5 / nDCG@10.

    Args:
        rank: GPU index and test-shard index.
        args: the notebook's `Args` configuration.

    Raises:
        AssertionError: if no checkpoint name is configured or none is found.
    """
    is_distributed = False  # single-process notebook run

    if args.enable_gpu:
        torch.cuda.set_device(rank)

    ckpt_path = None
    if args.load_ckpt_name is not None:
        ckpt_path = get_checkpoint(args.model_dir, args.load_ckpt_name)
    assert ckpt_path is not None, 'No checkpoint found.'
    # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location='cpu')

    # Use the vocabularies saved with the model. Rebuilding them from other data
    # would shift token/category ids away from the trained embedding rows.
    subcategory_dict = checkpoint['subcategory_dict']
    category_dict = checkpoint['category_dict']
    word_dict = checkpoint['word_dict']

    dummy_embedding_matrix = np.zeros((len(word_dict) + 1, args.word_embedding_dim))
    model = Model(args, dummy_embedding_matrix, len(category_dict), len(subcategory_dict))
    model.load_state_dict(checkpoint['model_state_dict'])
    logging.info(f"Model loaded from {ckpt_path}")

    if args.enable_gpu:
        model.cuda(rank)
    model.eval()

    # Scoped no_grad rather than torch.set_grad_enabled(False): the global switch
    # would persist in the notebook kernel and silently disable autograd for any
    # training cell run afterwards.
    with torch.no_grad():
        # The behaviors below come from the test split, so the news they reference
        # must come from the test split as well.
        news_path = os.path.join(args.test_data_dir, 'news.tsv')
        if args.use_custom_abstract:
            # NOTE(review): assumes the custom abstract CSV covers the test news ids.
            custom_abstract_df = pd.read_csv(args.custom_abstract_dir)
            custom_abstract_dict = custom_abstract_df.set_index('news_id')['abstract'].to_dict()
            news, news_index, _, _, _ = read_custom_abstract(news_path, custom_abstract_dict)
        else:
            # mode='train' is kept only for its 5-tuple return; the dictionaries it
            # rebuilds are discarded in favour of the checkpoint's.
            news, news_index, _, _, _ = read_news(news_path, args, mode='train')

        news_title, news_category, news_subcategory, news_abstract = get_doc_input(
            news, news_index, category_dict, subcategory_dict, word_dict, args)
        news_combined = np.concatenate(
            [x for x in (news_title, news_category, news_subcategory, news_abstract) if x is not None],
            axis=-1)

        news_dataloader = DataLoader(NewsDataset(news_combined),
                                     batch_size=args.batch_size,
                                     num_workers=4)

        news_scoring = []
        for input_ids in tqdm(news_dataloader):
            if args.enable_gpu:
                input_ids = input_ids.cuda(rank)
            news_vec = model.news_encoder(input_ids)
            news_scoring.extend(news_vec.detach().cpu().numpy())

        news_scoring = np.array(news_scoring)
        logging.info("news scoring num: {}".format(news_scoring.shape[0]))

        if rank == 0:
            logging.info(f'News doc-sim: {_mean_pairwise_cosine(news_scoring)}')

        data_file_path = os.path.join(args.test_data_dir, f'behaviors_{rank}.tsv')
        dataset = DatasetTest(data_file_path, news_index, news_scoring, args)
        dataloader = DataLoader(dataset, batch_size=args.batch_size, collate_fn=_collate_test_batch)

        AUC = []
        MRR = []
        nDCG5 = []
        nDCG10 = []

        def print_metrics(rank, cnt, x):
            logging.info("[{}] {} samples: {}".format(rank, cnt, '\t'.join(["{:0.2f}".format(i * 100) for i in x])))

        def get_mean(arr):
            return [np.array(i).mean() for i in arr]

        def get_sum(arr):
            return [np.array(i).sum() for i in arr]

        local_sample_num = 0

        for cnt, (log_vecs, log_mask, news_vecs, labels) in enumerate(dataloader):
            local_sample_num += log_vecs.shape[0]

            if args.enable_gpu:
                log_vecs = log_vecs.cuda(rank, non_blocking=True)
                log_mask = log_mask.cuda(rank, non_blocking=True)

            user_vecs = model.user_encoder(log_vecs, log_mask).detach().cpu().numpy()

            for user_vec, news_vec, label in zip(user_vecs, news_vecs, labels):
                # Ranking metrics such as AUC are undefined when all labels are equal.
                label_mean = label.mean()
                if label_mean == 0 or label_mean == 1:
                    continue

                score = np.dot(news_vec, user_vec)

                AUC.append(roc_auc_score(label, score))
                MRR.append(mrr_score(label, score))
                nDCG5.append(ndcg_score(label, score, k=5))
                nDCG10.append(ndcg_score(label, score, k=10))

            if cnt % args.log_steps == 0:
                print_metrics(rank, local_sample_num, get_mean([AUC, MRR, nDCG5, nDCG10]))

    logging.info('[{}] local_sample_num: {}'.format(rank, local_sample_num))
    if is_distributed:
        # Average over scored impressions (len(AUC)), not over all samples:
        # impressions skipped above contribute nothing to the metric sums.
        scored_num = torch.tensor(float(len(AUC))).cuda(rank)
        dist.reduce(scored_num, dst=0, op=dist.ReduceOp.SUM)
        total_sample_num = torch.tensor(local_sample_num).cuda(rank)
        dist.reduce(total_sample_num, dst=0, op=dist.ReduceOp.SUM)
        local_metrics_sum = torch.tensor(get_sum([AUC, MRR, nDCG5, nDCG10]), dtype=torch.float).cuda(rank)
        dist.reduce(local_metrics_sum, dst=0, op=dist.ReduceOp.SUM)
        if rank == 0:
            print_metrics('*', int(total_sample_num.item()), (local_metrics_sum / scored_num).tolist())
    else:
        print_metrics('*', local_sample_num, get_mean([AUC, MRR, nDCG5, nDCG10]))


In [None]:
# Configure the notebook-wide `args` for evaluation of the epoch-5 checkpoint.
args.mode = 'test'
args.user_log_mask = True
args.batch_size = 128
args.load_ckpt_name = 'epoch-5.pt'
args.prepare = True
if 'test' in args.mode:
    if args.prepare:
        logging.info('Preparing testing data...')
        total_sample_num = prepare_testing_data(args.test_data_dir, args.nGPU)
    else:
        total_sample_num = 0
        for i in range(args.nGPU):
            data_file_path = os.path.join(args.test_data_dir, f'behaviors_{i}.tsv')
            if not os.path.exists(data_file_path):
                logging.error(f'Splited testing data {data_file_path} for GPU {i} does not exist. Please set the parameter --prepare as True and rerun the code.')
                # Raise instead of exit(): exit() inside a Jupyter kernel asks the
                # kernel to shut down rather than just stopping this cell.
                raise FileNotFoundError(data_file_path)
            # Count lines in Python rather than via `wc -l {path}`: the Drive path
            # contains a space ("Colab Notebooks"), which the unquoted shell command
            # split into two arguments, making int() fail on the output.
            with open(data_file_path, 'rb') as sample_file:
                total_sample_num += sum(1 for _ in sample_file)
        logging.info('Skip testing data preparation.')
    logging.info(f'{total_sample_num} testing samples in total.')

    test(0, args)

INFO:root:Preparing testing data...


[INFO 2024-12-03 20:13:56,904] Preparing testing data...


73152it [00:00, 435149.49it/s]
INFO:root:Writing files...


[INFO 2024-12-03 20:13:57,090] Writing files...


INFO:root:73152 testing samples in total.


[INFO 2024-12-03 20:13:57,328] 73152 testing samples in total.


  checkpoint = torch.load(ckpt_path, map_location='cpu')
INFO:root:Model loaded from /content/model/epoch-5.pt


[INFO 2024-12-03 20:13:57,389] Model loaded from /content/model/epoch-5.pt


42416it [00:05, 7546.04it/s]
100%|██████████| 42416/42416 [00:00<00:00, 69105.57it/s]
100%|██████████| 332/332 [00:01<00:00, 195.11it/s]
INFO:root:news scoring num: 42417


[INFO 2024-12-03 20:14:05,418] news scoring num: 42417


100%|██████████| 1000000/1000000 [00:11<00:00, 87974.41it/s]
INFO:root:News doc-sim: 0.10123107692953237


[INFO 2024-12-03 20:14:16,792] News doc-sim: 0.10123107692953237


  log_vecs = torch.FloatTensor([x[0] for x in tuple_list])
INFO:root:[0] 128 samples: 67.71	32.34	35.08	42.27


[INFO 2024-12-03 20:14:17,473] [0] 128 samples: 67.71	32.34	35.08	42.27


INFO:root:[0] 12928 samples: 66.42	32.10	35.28	41.55


[INFO 2024-12-03 20:15:35,529] [0] 12928 samples: 66.42	32.10	35.28	41.55


INFO:root:[0] 25728 samples: 66.27	32.13	35.38	41.50


[INFO 2024-12-03 20:16:57,044] [0] 25728 samples: 66.27	32.13	35.38	41.50


INFO:root:[0] 38528 samples: 66.17	31.90	35.15	41.33


[INFO 2024-12-03 20:18:17,310] [0] 38528 samples: 66.17	31.90	35.15	41.33


INFO:root:[0] 51328 samples: 66.08	31.73	34.94	41.13


[INFO 2024-12-03 20:19:35,375] [0] 51328 samples: 66.08	31.73	34.94	41.13


INFO:root:[0] 64128 samples: 66.12	31.78	34.98	41.21


[INFO 2024-12-03 20:20:54,774] [0] 64128 samples: 66.12	31.78	34.98	41.21


INFO:root:[0] local_sample_num: 73152


[INFO 2024-12-03 20:21:52,435] [0] local_sample_num: 73152


INFO:root:[*] 73152 samples: 66.12	31.80	35.01	41.23


[INFO 2024-12-03 20:21:52,470] [*] 73152 samples: 66.12	31.80	35.01	41.23
