In [2]:
import os.path as op
import random
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader

class MovielensData(Dataset):
    """Dataset of item-interaction sessions read from a pickled DataFrame.

    On construction the class loads ``id2name.txt`` (``id::title`` lines) and
    the stage-specific ``similar_<stage>_data.df`` pickle from ``data_dir``,
    keeps the sessions whose ``len_seq`` is at least 3, and derives un-padded
    sequences plus title columns.  Indexing returns a dict describing one
    session together with ``cans_num`` candidate items: ``cans_num - 1``
    random negatives and the true next item, shuffled.

    Args:
        data_dir: Directory holding ``id2name.txt`` and the stage pickles.
        stage: One of ``'train'``, ``'val'`` or ``'test'``.
        cans_num: Number of candidates returned per sample.
        sep: Separator used to join titles into ``seq_str`` / ``cans_str``.
        no_augment: Together with ``stage`` determines ``self.aug``.

    Raises:
        ValueError: If ``stage`` is not one of the supported stage names.
    """

    # Stage name -> pickle file expected inside ``data_dir``.
    _STAGE_FILES = {
        'train': "similar_train_data.df",
        'val': "similar_val_data.df",
        'test': "similar_test_data.df",
    }

    def __init__(self, data_dir=r'data/ref/movielens',
                 stage=None,
                 cans_num=10,
                 sep=", ",
                 no_augment=True):
        super().__init__()
        # Explicit assignment instead of ``self.__dict__.update(locals())``,
        # which also stored ``self`` on the instance (a reference cycle).
        self.data_dir = data_dir
        self.stage = stage
        self.cans_num = cans_num
        self.sep = sep
        self.no_augment = no_augment
        # Computed here but not read anywhere else in this class.
        self.aug = (stage == 'train') and not no_augment
        # ID that pads 'seq' to a fixed length; stripped to build 'seq_unpad'.
        self.padding_item_id = 4606
        # Not read anywhere else in this class.
        self.padding_rating = 0
        self.check_files()

    def __len__(self):
        """Return the number of sessions kept after filtering."""
        return len(self.session_data)

    def __getitem__(self, i):
        """Return the ``i``-th session (positional) as a dict.

        Besides the session columns, the dict holds freshly sampled
        candidates (``cans``, ``cans_name``, ``cans_str``), so repeated access
        to the same index generally yields different candidates.
        """
        temp = self.session_data.iloc[i]
        candidates = self.negative_sampling(temp['seq_unpad'], temp['next'])
        cans_name = [self.item_id2name.get(can, "Unknown") for can in candidates]
        sample = {
            'seq': temp['seq'],
            'seq_name': temp['seq_title'],
            'len_seq': temp['len_seq'],
            'seq_str': self.sep.join(temp['seq_title']),
            'cans': candidates,
            'cans_name': cans_name,
            'cans_str': self.sep.join(cans_name),
            'len_cans': self.cans_num,
            'item_id': temp['next'],
            'item_name': temp['next_item_name'],
            'correct_answer': temp['next_item_name'],
            # The two '*_name' columns below are not created by
            # session_data4frame; they must already exist in the pickle.
            'most_similar_seq': temp['most_similar_seq'],
            'most_similar_seq_next': temp['most_similar_seq_next'],
            'most_similar_seq_name': temp['most_similar_seq_name'],
            'most_similar_seq_next_name': temp['most_similar_seq_next_name'],
        }
        return sample

    def negative_sampling(self, seq_unpad, next_item):
        """Return ``cans_num`` shuffled candidates containing ``next_item``.

        Negatives are drawn without replacement from the known item IDs that
        are neither in ``seq_unpad`` nor equal to ``next_item``.

        Raises:
            ValueError: If fewer than ``cans_num - 1`` negatives are available.
        """
        # A set gives O(1) membership tests; the candidate order (and thus the
        # outcome for a given random seed) is the same as with a list scan.
        excluded = set(seq_unpad)
        excluded.add(next_item)
        canset = [i for i in self.item_id2name if i not in excluded]
        needed = self.cans_num - 1
        if len(canset) < needed:
            raise ValueError(
                f"Cannot sample {needed} negatives from {len(canset)} available items."
            )
        candidates = random.sample(canset, needed) + [next_item]
        random.shuffle(candidates)
        return candidates

    def check_files(self):
        """Load the id->title map and the session table for ``self.stage``.

        Raises:
            ValueError: If ``self.stage`` is not 'train', 'val' or 'test'.
        """
        # Validate first so a bad stage fails with a clear message instead of
        # an UnboundLocalError on ``filename``.
        filename = self._STAGE_FILES.get(self.stage)
        if filename is None:
            raise ValueError(
                f"stage must be one of {sorted(self._STAGE_FILES)}, got {self.stage!r}"
            )
        self.item_id2name = self.get_movie_id2name()
        data_path = op.join(self.data_dir, filename)
        self.session_data = self.session_data4frame(data_path, self.item_id2name)

    def get_mv_title(self, s):
        """Move a trailing article to the front of a title.

        ``"Matrix, The (1999)"`` becomes ``"The Matrix (1999)"``.  The article
        is only moved when it ends the part before the first ``" ("``, so
        titles such as ``"Love, Actually"`` are returned unchanged.
        """
        head, paren, tail = s.partition(" (")
        for article in ("The", "An", "A"):
            suffix = ", " + article
            if head.endswith(suffix):
                return article + " " + head[:-len(suffix)] + paren + tail
        return s

    def get_movie_id2name(self):
        """Parse ``<data_dir>/id2name.txt`` into an ``{int id: title}`` dict.

        Each non-blank line is split on ``'::'``; the first field is the
        integer ID and the second the title (surrounding whitespace removed).
        """
        movie_id2name = dict()
        item_path = op.join(self.data_dir, 'id2name.txt')
        with open(item_path, 'r') as f:
            for line in f:
                if not line.strip():
                    # Tolerate blank lines (e.g. a trailing newline).
                    continue
                fields = line.strip('\n').split('::')
                movie_id2name[int(fields[0])] = fields[1].strip()
        return movie_id2name

    def session_data4frame(self, datapath, movie_id2name):
        """Load the pickled session table and add the derived columns.

        Adds ``seq_unpad`` (``seq`` without padding IDs), ``seq_title``,
        and ``next_item_name``; entries stored as tuples are reduced to their
        first element.  Progress and previews are printed to stdout.

        Args:
            datapath: Path of the pickled DataFrame.
            movie_id2name: Mapping from item ID to title.

        Returns:
            The filtered DataFrame with the extra columns.
        """
        print(f"Loading data from {datapath}\n")
        # NOTE(review): unpickling can execute arbitrary code; only load
        # files from a trusted source.
        train_data = pd.read_pickle(datapath)
        print("Initial train_data loaded\n")
        print(train_data.head(), "\n")
        # Copy the filtered slice so the column assignments below act on an
        # independent frame (avoids pandas' SettingWithCopyWarning).
        train_data = train_data[train_data['len_seq'] >= 3].copy()

        padding_item_id = self.padding_item_id

        def get_id(x):
            # A tuple entry carries the ID in its first position.
            if isinstance(x, tuple):
                return x[0]
            return x

        def get_id_list(x):
            # The length guard keeps an empty sequence from raising IndexError.
            if len(x) > 0 and isinstance(x[0], tuple):
                return [i[0] for i in x]
            return x

        def remove_padding(xx):
            # Drop every padding entry, however many there are.
            return [item for item in xx if get_id(item) != padding_item_id]

        def lookup_title(item):
            # Shared by the sequence and next-item lookups.
            item_id = get_id(item)
            if item_id in movie_id2name:
                return movie_id2name[item_id]
            print(f"KeyError: Movie ID {item_id} not found in movie_id2name.\n")
            return "Unknown"

        def seq_to_title(x):
            return [lookup_title(x_i) for x_i in x]

        train_data['seq_unpad'] = train_data['seq'].apply(remove_padding)
        print("Padding removed\n")
        print(train_data[['seq', 'seq_unpad']].head(), "\n")

        train_data['seq_title'] = train_data['seq_unpad'].apply(seq_to_title)
        print("Titles added\n")
        print(train_data[['seq_unpad', 'seq_title']].head(), "\n")

        train_data['next_item_name'] = train_data['next'].apply(lookup_title)
        print("Next item titles added\n")
        print(train_data[['next', 'next_item_name']].head(), "\n")

        train_data['next'] = train_data['next'].apply(get_id)
        train_data['seq'] = train_data['seq'].apply(get_id_list)
        train_data['seq_unpad'] = train_data['seq_unpad'].apply(get_id_list)
        train_data['most_similar_seq_next'] = train_data['most_similar_seq_next'].apply(get_id)
        # NOTE(review): 'most_similar_seq' looks like a sequence column, yet it
        # goes through the scalar helper -- confirm whether get_id_list was
        # intended.  Left unchanged because the column's format is not
        # visible here.
        train_data['most_similar_seq'] = train_data['most_similar_seq'].apply(get_id)
        print("IDs converted\n")
        print(train_data.head(), "\n")

        return train_data

# Assumes the data directory and stage below are set correctly for this machine.
dataset = MovielensData(data_dir='/workspace/LLaRA/data/ref/movielens', stage='train')
dataloader = DataLoader(dataset, batch_size=1, shuffle=True)  # batch size of 1

# Fetch one batch of data and print it.
for batch in dataloader:
    print(batch, "\n")
    break  # only print the first batch

# Add logging in the training loop
def train_model(dataloader):
    """Print the first batch produced by ``dataloader`` and return.

    This is a debugging aid rather than a real training loop: it pulls at
    most one batch, prints it with its index, and leaves the rest of the
    loader unread.  Nothing is printed when the loader yields no batches.
    """
    numbered_batches = enumerate(dataloader)
    # enumerate() only yields tuples, so None is a safe "empty" sentinel.
    first = next(numbered_batches, None)
    if first is None:
        return
    index, payload = first
    # Only a single batch is printed.
    print(f"Batch {index}: {payload}\n")

# Debug run: train_model only prints the first batch from the loader.
train_model(dataloader)


Loading data from /workspace/LLaRA/data/ref/movielens/similar_train_data.df

Initial train_data loaded

                                                 seq  len_seq  next  \
0  [4606, 4606, 4606, 4606, 4606, 4606, 4606, 460...        1     0   
1  [0, 4606, 4606, 4606, 4606, 4606, 4606, 4606, ...        1     1   
2  [0, 1, 4606, 4606, 4606, 4606, 4606, 4606, 460...        2     2   
3  [0, 1, 2, 4606, 4606, 4606, 4606, 4606, 4606, ...        3     3   
4   [0, 1, 2, 3, 4606, 4606, 4606, 4606, 4606, 4606]        4     4   

                                    movie_names_only  \
0  [4606, 4606, 4606, 4606, 4606, 4606, 4606, 460...   
1  [Morcheeba, 4606, 4606, 4606, 4606, 4606, 4606...   
2  [Morcheeba, Enigma, 4606, 4606, 4606, 4606, 46...   
3  [Morcheeba, Enigma, CafÃ© Del Mar, 4606, 4606,...   
4  [Morcheeba, Enigma, CafÃ© Del Mar, Fleetwood M...   

                              most_similar_seq_index  \
0  [18110.0, 17948.0, 40368.0, 17925.0, 17921.0, ...   
1  [48119.0, 1594.0,