In [1]:
import torch
import torch.nn as nn
from dataset.msrvtt_dataloader import MSRVTT_DataLoader
from model.fusion_model import EverythingAtOnceModel
from gensim.models.keyedvectors import KeyedVectors
from torch.utils.data import DataLoader

import argparse

  from .autonotebook import tqdm as notebook_tqdm


In [4]:
args = argparse.Namespace(
    we_path='./data/GoogleNews-vectors-negative300.bin',
    data_path='C:/Users/heeryung/code/24w_deep_daiv/msrvtt_category_test.pkl',
    checkpoint_path='C:/Users/heeryung/code/24w_deep_daiv/ckpt/trial5_classifier/epoch200.pth',
    token_projection='projection_net',
    use_softmax=False,
    use_cls_token=False,
    num_classes=20,
    batch_size=16
)

In [5]:
checkpoint = torch.load(args.checkpoint_path)

we = None 
we = KeyedVectors.load_word2vec_format(args.we_path, binary=True)

dataset = MSRVTT_DataLoader(data_path=args.data_path, we=we)
data_loader = DataLoader(dataset, batch_size=args.batch_size)

net = EverythingAtOnceModel(args).cuda()
optimizer = torch.optim.Adam(net.parameters(), lr =0.001)

net.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
net.eval()

def get_soft_voting(va, at, tv):
    # Soft voting by averaging the probabilities
    soft_vote = (va + at + tv) / 3
    _, soft_vote_preds = torch.max(soft_vote, 1)
    return soft_vote_preds

def get_hard_voting(va_preds, at_preds, tv_preds):
    # Hard voting by selecting the most frequent prediction
    combined_preds = torch.stack((va_preds, at_preds, tv_preds), dim=1)
    hard_vote, _ = torch.mode(combined_preds, dim=1)
    return hard_vote

def get_predictions(va, at, tv):
    _, va_preds = torch.max(va, 1)
    _, at_preds = torch.max(at, 1)
    _, tv_preds = torch.max(tv, 1)
    return va_preds, at_preds, tv_preds

def calculate_accuracy(predictions, labels):
    correct = (predictions == labels).sum().item()
    total = labels.size(0)
    accuracy = correct / total
    return accuracy

total_samples = 0
total_accuracy = 0
total_video_correct = 0
total_audio_correct = 0
total_text_correct = 0
total_hard_vote_correct = 0
total_soft_vote_correct = 0

In [6]:
for data in data_loader:
    video = data['video'].cuda()
    audio = data['audio'].cuda()
    text = data['text'].cuda()
    nframes = data['nframes'].cuda()
    category = data['category'].cuda() # [batch_size,]

    video = video.view(-1, video.shape[-1])
    audio = audio.view(-1, audio.shape[-2], audio.shape[-1])
    text = text.view(-1, text.shape[-2], text.shape[-1])

    pred = net(video, audio, nframes, text, category) # [batch_size, 20]
    print(pred)
    pred_category = torch.argmax(pred, dim=1) # [batch_size,]
    accuracy = torch.mean((pred_category == category).float()) # [batch_size,]
    print(pred_category, '/', category)

    total_accuracy += accuracy

# Calculate final accuracies
accuracy = total_accuracy / len(data_loader)

print("Accuracy:", accuracy)

tensor([[ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
      

KeyboardInterrupt: 

In [7]:
pred

tensor([[ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
          0.2847, -0.4040, -0.0351, -1.1323],
        [ 0.3530, -0.7100, -0.0696,  0.8187, -0.0030, -0.5496, -0.4708,  0.6619,
         -0.5395,  0.5044,  0.1096, -0.6115, -0.3472,  0.2192,  0.0657, -1.1334,
      

In [8]:
_, p = torch.max(pred.data, 1)

In [9]:
p

tensor([3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3], device='cuda:0')

In [10]:
_

tensor([0.8187, 0.8187, 0.8187, 0.8187, 0.8187, 0.8187, 0.8187, 0.8187, 0.8187,
        0.8187, 0.8187, 0.8187, 0.8187, 0.8187, 0.8187, 0.8187],
       device='cuda:0')

In [11]:
import torch
a = torch.randn(16,64,256)

In [15]:
a.size(0)

16

In [None]:
class Classifier(nn.Module):
    def __init__(self):
        super(Classifier, self).__init__()
        self.video_fc = nn.Linear(video_dim * embedding_size, 512)
        self.audio_fc = nn.Linear(audio_dim * embedding_size, 512)
        self.text_fc = nn.Linear(text_dim * embedding_size, 512)
        self.final_fc = nn.Linear(512 * 3, num_classes)  # 3: 비디오, 오디오, 텍스트 입력의 수

    def forward(self, video, audio, text):
        video_flat = video.view(video.size(0), -1)
        audio_flat = audio.view(audio.size(0), -1)
        text_flat = text.view(text.size(0), -1)
        
        video_out = torch.relu(self.video_fc(video_flat))
        audio_out = torch.relu(self.audio_fc(audio_flat))
        text_out = torch.relu(self.text_fc(text_flat))
        
        combined = torch.cat((video_out, audio_out, text_out), dim=1)
        logits = self.final_fc(combined)
        return logits