In [29]:
import os
import pickle
import time
from pathlib import Path

import cv2
import numpy as np
import pandas as pd
import torch
import torchvision.models as models
from sklearn.preprocessing import LabelEncoder
from torch import nn
from torch.utils.data import DataLoader, Dataset
from torchvision.transforms import v2 as transforms

# Fixed seed so the shuffle below -- and therefore any split that relies on
# row order -- is the same on every run of the notebook.
SHUFFLE_SEED = 42

# Rewrite the absolute path prefix recorded in each CSV to the local 'synth3'
# directory.  regex=False makes the prefix match literally, independent of the
# pandas version's default for `str.replace`.
labels1 = pd.read_csv('synth3/labels.csv')
labels1['filename'] = labels1['filename'].str.replace(
    '/home/roman/gagarin/data_generation/data/final_dataset1', 'synth3', regex=False
)
labels2 = pd.read_csv('synth3/labels2.csv')
labels2['filename'] = labels2['filename'].str.replace(
    '/home/roman/gagarin/data_generation/data/final_dataset', 'synth3', regex=False
)

# Merge both label tables, shuffle the rows reproducibly and give the result a
# clean 0..n-1 index (concat would otherwise keep duplicated index values).
labels = (
    pd.concat([labels1, labels2], ignore_index=True)
    .sample(frac=1.0, random_state=SHUFFLE_SEED)
    .reset_index(drop=True)
)

# Encode the 'anomaly' column as integer class ids; `le` keeps the mapping so
# predictions can be decoded back later.
le = LabelEncoder()
labels['anomaly'] = le.fit_transform(labels['anomaly'])

In [33]:
class VideoDataset(Dataset):
    """Dataset yielding the first frames of a video file together with its label.

    Each item is a ``(frames, label)`` pair where ``frames`` stacks
    ``n_frames`` processed frames along a new leading dimension.
    """

    def __init__(self, labels, transform=None, n_frames=2, frame_size=384):
        """
        Args:
            labels (DataFrame): DataFrame with two columns, the first holding
                file paths and the second holding labels.
            transform (callable, optional): Optional transform applied to every
                frame (after colour conversion and resizing).
            n_frames (int): Number of frames returned per video.
            frame_size (int): Side length, in pixels, frames are resized to
                before the transform is applied.
        """
        self.labels = labels
        self.transform = transform
        self.n_frames = n_frames
        self.frame_size = frame_size

    def __len__(self):
        """Return the number of rows in the label table."""
        return len(self.labels)

    @staticmethod
    def _h264_path(path):
        """Return the path of the ``.h264`` copy used for decoding ``path``.

        Only a trailing ``.bin`` extension is swapped, so directory names that
        happen to contain "bin" are left alone.  Paths without that extension
        fall back to the plain substring replacement used previously.
        """
        if path.endswith('.bin'):
            return path[:-len('.bin')] + '.h264'
        return path.replace('bin', 'h264')

    def __getitem__(self, idx):
        """Decode, preprocess and stack the first ``n_frames`` frames of item ``idx``.

        Returns:
            tuple: ``(frames, label)``; ``frames`` is the ``torch.stack`` of the
            per-frame results along dimension 0.

        Raises:
            RuntimeError: If not a single frame could be read from the file.
        """
        filename, label = self.labels.iloc[idx]
        target_path = self._h264_path(filename)

        # Copy the file under the new extension (done once, then reused) --
        # presumably so that OpenCV recognises the stream by its extension.
        if not Path(target_path).exists():
            Path(target_path).write_bytes(Path(filename).read_bytes())

        # Read the video; the capture is released even if decoding raises.
        video = cv2.VideoCapture(target_path)
        seq = []
        try:
            while len(seq) < self.n_frames:
                ret, frame = video.read()
                if not ret:
                    break

                frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                frame = cv2.resize(
                    frame,
                    (self.frame_size, self.frame_size),
                    interpolation=cv2.INTER_AREA,
                )
                if self.transform:
                    frame = self.transform(frame)
                else:
                    frame = torch.from_numpy(frame)
                seq.append(frame)
        finally:
            video.release()

        if not seq:
            raise RuntimeError(f"No frames could be decoded from {target_path!r}")

        # Videos shorter than n_frames are padded by repeating the last frame
        # so every sample has the same length and batches can be collated.
        while len(seq) < self.n_frames:
            seq.append(seq[-1])

        return torch.stack(seq, dim=0), label

# Hold out the last 15% of the rows for evaluation.
# The split position is computed explicitly: slicing with a negative size
# breaks when test_size is 0, because `labels[-0:]` is the whole frame and
# `labels[:-0]` is empty.
test_size = int(len(labels) * 0.15)
split_at = len(labels) - test_size
test_list = labels.iloc[split_at:]
train_list = labels.iloc[:split_at]

In [35]:
class Model(nn.Module):
    """Frame-sequence classifier: ResNet-18 features averaged over frames, then an MLP head."""

    def __init__(self, n_class=2):
        """
        Args:
            n_class (int): Number of output classes of the final linear layer.
        """
        super().__init__()

        # The weights enum replaces the deprecated `pretrained=True` flag and
        # selects the same ImageNet checkpoint.
        self.resnet = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1)
        # Drop the ImageNet classifier so the backbone outputs its 512-d features.
        self.resnet.fc = nn.Identity()

        self.fc = nn.Sequential(
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, n_class),
        )

    def forward(self, x):
        """Classify a batch of frame sequences.

        Args:
            x (Tensor): Input whose first two dimensions are batch and sequence;
                the remaining dimensions are passed to the backbone as one image
                per frame.

        Returns:
            Tensor: Output of the head, one row per batch element.
        """
        batch_size, seq_size = x.shape[:2]
        # Fold the sequence dimension into the batch so the backbone sees plain
        # images; the frame shape is taken from the input rather than fixed.
        x = self.resnet(x.reshape(batch_size * seq_size, *x.shape[2:]))
        x = x.reshape(batch_size, seq_size, -1)
        # Average the per-frame features over the sequence dimension.
        x = x.mean(1)
        return self.fc(x)
    
# Use the GPU when one is available and fall back to the CPU otherwise, so the
# notebook still runs on machines without CUDA.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# NOTE(review): 5 output classes is hard-coded -- presumably the number of
# distinct 'anomaly' labels; confirm against len(le.classes_).
model = Model(5).to(device)



In [36]:
def train(dataloader, criterion, optimizer, epoch, log_interval=1):
    """Run one training pass over ``dataloader`` using the global ``model``.

    Args:
        dataloader: Iterable of ``(data, label)`` batches that supports ``len()``.
        criterion: Loss called as ``criterion(predictions, labels)``.
        optimizer: Optimizer that updates the parameters of ``model``.
        epoch (int): Epoch number, used only in the progress line.
        log_interval (int): Print a progress line every this many batches (and
            after the final batch).  The reported accuracy covers the batches
            since the previous line; the default of 1 reports every batch.
    """
    model.train()
    window_correct, window_seen = 0, 0
    n_batches = len(dataloader)

    for idx, (data, label) in enumerate(dataloader):
        optimizer.zero_grad()
        data = data.to(device)
        label = label.to(dtype=torch.long, device=device)
        predicted_label = model(data)

        loss = criterion(predicted_label, label)
        loss.backward()
        optimizer.step()

        window_correct += (predicted_label.argmax(1) == label).sum().item()
        window_seen += label.size(0)

        # Report at the end of each logging window and once more for a
        # trailing partial window, then start counting afresh.
        if (idx + 1) % log_interval == 0 or idx + 1 == n_batches:
            print(
                "| epoch {:3d} | {:5d}/{:5d} batches "
                "| accuracy {}".format(
                    epoch, idx, n_batches, window_correct / window_seen
                )
            )
            window_correct, window_seen = 0, 0
        
def evaluate(dataloader):
    """Return the classification accuracy of the global ``model`` on ``dataloader``.

    The model is switched to eval mode and gradients are disabled for the pass.

    Args:
        dataloader: Iterable of ``(data, label)`` batches.

    Returns:
        float: Fraction of correctly classified samples, or NaN when the
        loader yields no samples (accuracy is undefined in that case).
    """
    model.eval()
    correct, seen = 0, 0

    with torch.no_grad():
        for data, label in dataloader:
            data = data.to(device)
            label = label.to(dtype=torch.long, device=device)
            predicted_label = model(data)
            correct += (predicted_label.argmax(1) == label).sum().item()
            seen += label.size(0)

    # Guard against an empty loader instead of raising ZeroDivisionError.
    if seen == 0:
        return float('nan')
    return correct / seen

In [37]:
bs = 128

# Training-time preprocessing: random crop to 224 and random horizontal flip
# as augmentation, followed by normalisation with the given per-channel
# mean/std.
transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.RandomResizedCrop(224),
                    transforms.RandomHorizontalFlip(),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
                ])

# Evaluation-time preprocessing: deterministic resize, same normalisation.
val_transform = transforms.Compose(
                [
                    transforms.ToTensor(),
                    transforms.Resize(224),
                    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
                ])

train_dataset = VideoDataset(train_list, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=bs, shuffle=True)

# Accuracy does not depend on sample order, so the held-out set is not shuffled.
test_dataset = VideoDataset(test_list, transform=val_transform)
test_loader = DataLoader(test_dataset, batch_size=bs, shuffle=False)

EPOCHS = 3
LR = 1e-3

criterion = torch.nn.CrossEntropyLoss()
# Pass LR explicitly so the constant above actually controls the optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=LR)
# Multiply the learning rate by 0.7 after every epoch (step_size is an epoch count).
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1, gamma=0.7)
total_accu = None

for epoch in range(1, EPOCHS + 1):
    epoch_start_time = time.time()
    train(train_loader, criterion, optimizer, epoch)

    accu_val = evaluate(test_loader)

    print("-" * 59)
    print(
        "| end of epoch {:3d} | time: {:5.2f}s | "
        "valid accuracy {:8.3f} ".format(
            epoch, time.time() - epoch_start_time, accu_val
        )
    )
    print("-" * 59)

    scheduler.step()





| epoch   1 |     0/   96 batches | accuracy 0.4375
| epoch   1 |     1/   96 batches | accuracy 0.828125
| epoch   1 |     2/   96 batches | accuracy 0.8203125
| epoch   1 |     3/   96 batches | accuracy 0.8359375
| epoch   1 |     4/   96 batches | accuracy 0.796875
| epoch   1 |     5/   96 batches | accuracy 0.953125
| epoch   1 |     6/   96 batches | accuracy 0.9453125
| epoch   1 |     7/   96 batches | accuracy 0.9140625
| epoch   1 |     8/   96 batches | accuracy 0.875
| epoch   1 |     9/   96 batches | accuracy 0.921875
| epoch   1 |    10/   96 batches | accuracy 0.953125
| epoch   1 |    11/   96 batches | accuracy 0.90625
| epoch   1 |    12/   96 batches | accuracy 0.9296875
| epoch   1 |    13/   96 batches | accuracy 0.90625
| epoch   1 |    14/   96 batches | accuracy 0.9140625
| epoch   1 |    15/   96 batches | accuracy 0.9609375
| epoch   1 |    16/   96 batches | accuracy 0.90625
| epoch   1 |    17/   96 batches | accuracy 0.9609375
| epoch   1 |    18/   96 ba

In [None]:
# Create the output directory first so a missing folder cannot cause the save
# to fail after training has already finished.
Path('resnet').mkdir(parents=True, exist_ok=True)

# Persist the trained weights (state dict only).
with open('resnet/model.pt', 'wb') as f:
    torch.save(model.state_dict(), f)

# Persist the fitted label encoder alongside the weights so class ids can be
# mapped back to the original labels.
with open("resnet/le.pkl", "wb") as f:
    pickle.dump(le, f)