In [None]:
import pandas as pd
import json
import torch
import torch.nn as nn
from transformers import BertTokenizer, BertModel
import numpy as np
import random
from tqdm import tqdm
import os
# Server setup: number GPUs by PCI bus id and expose only GPU 0 to this process.
# Must run before CUDA is initialised (torch.cuda.is_available() below).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

# Seed every RNG the notebook uses (random, numpy, torch) so runs are reproducible.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the CSV files
filtered_data = pd.read_csv('./filtered_data_with_prompt.csv')
filtered_song_meta = pd.read_csv('./filtered_song_meta.csv')

# Collect every song id referenced by any playlist ('songs' holds a JSON-encoded list)
filtered_song_ids = set()
for songs in filtered_data['songs']:
    filtered_song_ids.update(json.loads(songs))

# Song metadata as {id: {column: value}}
song_meta_dict = filtered_song_meta.set_index('id').to_dict('index')

# The 'embedding' column is stored as a JSON string; decode it into a list of floats
for song_info in song_meta_dict.values():
    song_info['embedding'] = json.loads(song_info['embedding'])

In [None]:
# Load the multilingual BERT model and tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
bert_model = BertModel.from_pretrained('bert-base-multilingual-cased').to(device)
bert_model.eval()  # used for inference only: keep dropout off so encodings are deterministic


def encode_text(text):
    """Encode ``text`` with multilingual BERT and mean-pool the last hidden state.

    Args:
        text: the string (or list of strings) to encode; inputs are truncated to the
            tokenizer's maximum length.

    Returns:
        A detached tensor of shape (batch, hidden_size) on ``device``. The rest of the
        notebook assumes hidden_size == 768.
    """
    # No autograd graph is needed (the result is detached anyway). Building one through
    # BERT on every call only costs memory and time, and this is called very often.
    with torch.no_grad():
        inputs = tokenizer(text, return_tensors='pt', padding=True, truncation=True).to(device)
        outputs = bert_model(**inputs)
    # NOTE: for a batch of strings of different lengths, padded positions are included in this mean.
    return outputs.last_hidden_state.mean(dim=1).detach()



In [None]:
class QNetwork(nn.Module):
    """Three-layer MLP that maps a song's album, title and genre embeddings to one vector.

    The album and song embeddings are each ``text_embedding_dim`` wide and the genre
    embedding is ``genre_embedding_dim`` wide. They are concatenated along dim 1 and mapped
    to a ``text_embedding_dim``-sized output (768 in this notebook), so the output can be
    compared with a text (prompt) embedding.
    """

    def __init__(self, text_embedding_dim, genre_embedding_dim, hidden_dim):
        super().__init__()
        concat_dim = 2 * text_embedding_dim + genre_embedding_dim
        # The attribute names fc1/fc2/fc3 are the state_dict keys of saved checkpoints.
        self.fc1 = nn.Linear(concat_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, text_embedding_dim)

    def forward(self, album_emb, song_emb, genre_emb):
        features = torch.cat([album_emb, song_emb, genre_emb], dim=1)
        hidden = self.fc1(features).relu()
        hidden = self.fc2(hidden).relu()
        return self.fc3(hidden)


In [None]:
class MusicRecommendationEnv:
    """Environment that serves candidate songs for a playlist prompt and scores recommendations.

    Args:
        playlists: list of playlist records. Each needs a ``'songs'`` entry holding a
            JSON-encoded list of song ids and a ``'prompt'`` entry.
        song_meta: dict mapping song id -> metadata dict with ``'album_name'``,
            ``'song_name'`` and ``'embedding'`` keys.
        num_samples: target number of candidate songs per state.
        max_positive: maximum number of the playlist's own songs put into a state.
        episode_length: number of recommendations after which ``step`` reports done.
    """

    def __init__(self, playlists, song_meta, num_samples=50, max_positive=20, episode_length=10):
        self.playlists = playlists
        self.song_meta = song_meta
        self.num_samples = num_samples
        self.max_positive = max_positive
        self.episode_length = episode_length
        self.current_playlist_index = 0
        self.recommended_songs = set()
        self.action_space = len(self.song_meta)

    def reset(self):
        """Pick a random playlist, clear the recommendation history, return (candidates, prompt)."""
        self.current_playlist_index = np.random.randint(len(self.playlists))
        playlist = self.playlists[self.current_playlist_index]
        self.recommended_songs = set()
        return self._sample_songs(playlist), playlist['prompt']

    def step(self, action):
        """Recommend song id ``action`` for the current playlist.

        Returns:
            (next_state, next_prompt, album_emb, song_emb, genre_emb, reward, done), where
            the three embeddings describe the recommended song and ``genre_emb`` is 2-D.

        Raises:
            ValueError: if ``action`` is not a key of ``song_meta``.
        """
        if action not in self.song_meta:
            raise ValueError(f"Action {action} is not in song_meta.")
        playlist = self.playlists[self.current_playlist_index]
        reward = calculate_reward(action, playlist)

        # Record the action *before* checking termination, so the episode ends after exactly
        # `episode_length` recommendations (the old check ran one step late). This also keeps
        # the song just recommended out of the filler songs of the next state.
        self.recommended_songs.add(action)
        done = len(self.recommended_songs) >= self.episode_length

        next_state, next_prompt = self._sample_songs(playlist), playlist['prompt']
        meta = self.song_meta[action]
        album_emb = encode_text(meta['album_name'])
        song_emb = encode_text(meta['song_name'])
        genre_emb = torch.tensor(meta['embedding'], dtype=torch.float32).to(device)
        if genre_emb.dim() == 1:
            genre_emb = genre_emb.unsqueeze(0)
        return next_state, next_prompt, album_emb, song_emb, genre_emb, reward, done

    def _sample_songs(self, playlist):
        """Return a shuffled candidate list: some of the playlist's songs plus random fillers.

        Up to ``max_positive`` songs are drawn from the playlist, never more than
        ``num_samples``. The rest are drawn from ``song_meta`` songs that are neither already
        sampled nor already recommended, up to ``num_samples`` in total or as many as exist.
        """
        playlist_songs = json.loads(playlist['songs'])
        n_positive = min(self.max_positive, len(playlist_songs), self.num_samples)
        sampled_songs = random.sample(playlist_songs, n_positive)

        excluded = set(sampled_songs) | self.recommended_songs  # O(1) membership checks
        other_songs = [song_id for song_id in self.song_meta if song_id not in excluded]
        n_other = min(self.num_samples - len(sampled_songs), len(other_songs))
        sampled_songs += random.sample(other_songs, n_other)

        # Otherwise the playlist's own songs would always sit at the first indices,
        # leaking the answer through the candidate position (action_idx).
        random.shuffle(sampled_songs)
        return sampled_songs

def calculate_reward(recommended_song, playlist):
    """Score a recommendation against a playlist.

    Args:
        recommended_song: the recommended song id.
        playlist: record whose ``'songs'`` entry is a JSON-encoded list of song ids.

    Returns:
        1 if the song belongs to the playlist, otherwise -0.1.
    """
    playlist_songs = json.loads(playlist['songs'])
    return 1 if recommended_song in playlist_songs else -0.1
    


# Default environment (num_samples default); the training cell builds its own with 500 candidates.
env = MusicRecommendationEnv(filtered_data.to_dict('records'), song_meta_dict)



In [None]:
# Initialise the DQN model
text_embedding_dim = 768  # BERT base model output size
genre_embedding_dim = 8  # size of the genre embedding
hidden_dim = 256

model_path = './weights/from_simple_20_100_1000.pth'  # saved checkpoint to continue training from
model = QNetwork(text_embedding_dim, genre_embedding_dim, hidden_dim).to(device)
# map_location lets a checkpoint saved on a GPU load on a CPU-only machine (and vice versa)
model.load_state_dict(torch.load(model_path, map_location=device))

def _song_embeddings(env, song_id, cache):
    """Return ``(album_emb, song_emb, genre_emb)`` for ``song_id``, memoised in ``cache``.

    Only ``model`` is optimised, so a song's BERT encodings do not change during training.
    This assumes ``encode_text`` is deterministic (bert_model in eval mode). Each song is
    therefore encoded once instead of on every greedy step. The cache keeps tensors on
    ``device`` and grows with the number of distinct songs seen.
    """
    if song_id not in cache:
        meta = env.song_meta[song_id]
        album_emb = encode_text(meta['album_name'])
        song_emb = encode_text(meta['song_name'])
        genre_emb = torch.tensor(meta['embedding'], dtype=torch.float32).to(device)
        if genre_emb.dim() == 1:
            genre_emb = genre_emb.unsqueeze(0)
        cache[song_id] = (album_emb, song_emb, genre_emb)
    return cache[song_id]


def _select_action(env, model, state, prompt_emb, epsilon, song_cache):
    """Epsilon-greedy choice of a not-yet-recommended song from ``state``.

    Returns ``(song_id, index_in_state)``, or ``(None, None)`` when no candidate is eligible.
    """
    if np.random.rand() <= epsilon:
        # Explore: uniform over candidates not yet recommended. Check first so the
        # rejection loop below cannot spin forever.
        if all(song_id in env.recommended_songs for song_id in state):
            return None, None
        while True:
            action_idx = np.random.randint(len(state))
            action = state[action_idx]
            if action not in env.recommended_songs:
                return action, action_idx

    # Exploit: pick the candidate whose model output is most cosine-similar to the prompt.
    best_similarity = -float('inf')
    best = (None, None)
    with torch.no_grad():  # pure scoring: no gradients needed
        for idx, song_id in enumerate(state):
            if song_id not in env.song_meta or song_id in env.recommended_songs:
                continue
            song_vector = model(*_song_embeddings(env, song_id, song_cache))
            similarity = torch.cosine_similarity(song_vector, prompt_emb, dim=1).item()
            if similarity > best_similarity:
                best_similarity = similarity
                best = (song_id, idx)
    return best


def _replay(model, optimizer, criterion, memory, batch_size, gamma):
    """Sample ``batch_size`` stored transitions and take one optimiser step per transition."""
    for (album_emb, song_emb, genre_emb, _prompt_emb, reward, next_state_vec, done,
         action_idx) in random.sample(memory, batch_size):
        target = reward
        if not done:
            with torch.no_grad():
                # The stored (updated) prompt embedding is fed in the album-embedding slot
                next_q_values = model(next_state_vec, song_emb, genre_emb)
            target += gamma * next_q_values.max().item()

        current_q_values = model(album_emb, song_emb, genre_emb)

        # The target is the current output with only the stored action's entry replaced.
        # It is detached so it acts as a constant in the loss.
        # NOTE(review): action_idx is the song's position in the candidate list, used as an
        # index into the model output; this assumes num_samples <= the output size.
        target_q_values = current_q_values.detach().clone()
        target_q_values[0][action_idx] = target

        optimizer.zero_grad()
        loss = criterion(current_q_values, target_q_values)
        loss.backward()
        optimizer.step()


def train_dqn(env, model, episodes, gamma, epsilon, epsilon_min, epsilon_decay, batch_size, weight=0.01,
              memory_size=None):
    """Train ``model`` on ``env`` with an epsilon-greedy DQN-style loop and experience replay.

    Each episode resets ``env``, encodes the prompt with ``encode_text`` and recommends up to
    10 songs. After each recommendation the prompt embedding moves by
    ``weight * reward * song_vector``.

    Args:
        env: environment exposing ``reset``, ``step``, ``song_meta`` and ``recommended_songs``.
        model: module called as ``model(album_emb, song_emb, genre_emb)``.
        episodes: number of episodes to run.
        gamma: discount applied to the bootstrapped next-state value.
        epsilon: initial exploration probability.
        epsilon_min: lower bound below which ``epsilon`` stops decaying.
        epsilon_decay: factor ``epsilon`` is multiplied by after every episode.
        batch_size: replay starts once memory holds more than this many transitions,
            and this many transitions are replayed per step.
        weight: step size of the prompt-embedding update.
        memory_size: optional cap on the replay memory (oldest transitions are dropped).
            ``None`` keeps every transition.

    Returns:
        ``(model, rewards_per_episode)``: the trained model and each episode's total reward.
    """
    from collections import deque

    optimizer = torch.optim.Adam(model.parameters())
    criterion = nn.MSELoss()
    # Replay memory of (album_emb, song_emb, genre_emb, prompt_emb, reward,
    # next_state_vector, done, action_idx) tuples
    memory = deque(maxlen=memory_size)
    rewards_per_episode = []
    song_cache = {}

    for episode in tqdm(range(episodes), desc="Training DQN"):
        # NOTE: the candidate list `state` is fixed for the whole episode; the next_state
        # returned by env.step is not used.
        state, prompt = env.reset()
        prompt_emb = encode_text(prompt).view(1, -1)  # prompt -> embedding vector
        total_reward = 0

        for _ in range(10):  # recommend up to 10 songs per prompt
            action, action_idx = _select_action(env, model, state, prompt_emb, epsilon, song_cache)
            if action is None:
                continue

            _next_state, _next_prompt, album_emb, song_emb, genre_emb, reward, done = env.step(action)
            total_reward += reward

            # Move the prompt embedding along the recommended song's vector, scaled by the reward.
            # No gradient flows through this update, so build no graph. Otherwise every
            # prompt_emb stored in memory would keep an ever-growing autograd graph alive.
            with torch.no_grad():
                song_vector = model(album_emb, song_emb, genre_emb)
            prompt_emb = prompt_emb + weight * song_vector * reward

            next_state_vector = prompt_emb  # the updated prompt acts as the next state

            memory.append((album_emb, song_emb, genre_emb, prompt_emb, reward,
                           next_state_vector, done, action_idx))

            if done:
                break

            if len(memory) > batch_size:
                # Replay lives in a helper so its loop variables cannot clobber this episode's
                # prompt_emb. It also uses each transition's own stored action index.
                _replay(model, optimizer, criterion, memory, batch_size, gamma)

        rewards_per_episode.append(total_reward)
        if epsilon > epsilon_min:
            epsilon *= epsilon_decay

        if (episode + 1) % 10 == 0:
            average_reward = np.mean(rewards_per_episode[-10:])
            print(f"Episode {episode + 1}/{episodes} - Average Reward (last 10 episodes): {average_reward:.2f}")

    return model, rewards_per_episode

# Build the training environment with 500 candidate songs per state
env = MusicRecommendationEnv(filtered_data.to_dict('records'), song_meta_dict, num_samples=500)

# DQN training hyperparameters
episodes = 100000
gamma = 0.95           # discount factor
epsilon = 1.0          # initial exploration probability
epsilon_min = 0.01     # epsilon stops decaying below this
epsilon_decay = 0.995  # multiplicative decay applied after each episode
batch_size = 256

trained_model, rewards_per_episode = train_dqn(
    env,
    model,
    episodes=episodes,
    gamma=gamma,
    epsilon=epsilon,
    epsilon_min=epsilon_min,
    epsilon_decay=epsilon_decay,
    batch_size=batch_size,
)

In [None]:
# Save the trained weights
# NOTE(review): this name suggests 10000 episodes, but the training cell runs
# episodes = 100000. Confirm the intended checkpoint name.
weight_name = "from_simple_20_500_10000"
weight_path = os.path.join('./weights', f'{weight_name}.pth')
os.makedirs(os.path.dirname(weight_path), exist_ok=True)
torch.save(trained_model.state_dict(), weight_path)
print(f"Model saved to '{weight_path}'")

In [None]:
import matplotlib.pyplot as plt

# Total reward collected in each training episode
fig, ax = plt.subplots()
ax.plot(rewards_per_episode)
ax.set_xlabel('Episode')
ax.set_ylabel('Total Reward')
ax.set_title('Total Reward per Episode')
plt.show()