In [None]:
import easydict
import os
import sys
from PIL import Image
import tqdm
import shutil

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim import lr_scheduler
import torchvision.transforms as transforms
import torch.utils.data as data
from torchvision import transforms
import cv2
from glob import glob
import pandas as pd

from PIL import Image
from tqdm import tqdm
import dlib
from sklearn.model_selection import train_test_split

In [None]:
# Notebook-wide settings, exposed with attribute access (e.g. ``args.save_fn``).
# Only ``save_fn`` is read by the cells visible in this notebook; the other
# entries are presumably training settings kept for reference -- TODO confirm.
args = easydict.EasyDict(
    dict(
        num_workers=32,
        learning_rate=0.001,
        num_epochs=1,
        batch_size=32,
        # Checkpoint file whose 'state_dict' entry is loaded into the model.
        save_fn="deepfake_c0_xception_tuned.pth.tar",
    )
)

In [None]:
"""
Author: Andreas Rössler,
Implemented in https://github.com/ondyari/FaceForensics under MIT license
"""

class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels: Number of channels of the input tensor.
        out_channels: Number of channels produced by the pointwise convolution.
        kernel_size, stride, padding, dilation: Forwarded to the depthwise
            convolution only; the pointwise convolution is always 1x1 with
            stride 1, no padding and no dilation.
        bias: Whether both convolutions carry a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super().__init__()

        # groups == in_channels: every input channel gets its own spatial filter.
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 convolution that mixes the per-channel results into out_channels.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.conv1(x))


class Block(nn.Module):
    """Residual block built from ReLU -> SeparableConv2d -> BatchNorm2d units.

    Args:
        in_filters: Channels of the incoming tensor.
        out_filters: Channels of the tensor returned by the block.
        reps: Total number of separable-convolution units in the main branch
            (one channel-changing unit plus ``reps - 1`` same-width units).
        strides: When different from 1, a 3x3 max-pool with this stride is
            appended to the main branch and the shortcut convolution uses the
            same stride.
        start_with_relu: If False the leading ReLU of the main branch is
            dropped; if True it is replaced by a non-in-place ReLU so the
            block input is still intact when the shortcut is computed.
        grow_first: If True the channel-changing unit comes first, otherwise
            it comes last.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super().__init__()

        # A 1x1 projection is only needed when the shortcut must change the
        # channel count or the spatial resolution.
        identity_shortcut = out_filters == in_filters and strides == 1
        if identity_shortcut:
            self.skip = None
        else:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)

        self.relu = nn.ReLU(inplace=True)

        # Channel plan of the main branch: one (in -> out) unit plus reps-1
        # units that keep the width they are given.
        growth = (in_filters, out_filters)
        width = out_filters if grow_first else in_filters
        steady = [(width, width)] * (reps - 1)
        plan = [growth] + steady if grow_first else steady + [growth]

        layers = []
        for c_in, c_out in plan:
            layers.append(self.relu)
            layers.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))

        if start_with_relu:
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers.pop(0)

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        """Return main-branch output plus the (optionally projected) input."""
        out = self.rep(inp)
        shortcut = inp if self.skip is None else self.skipbn(self.skip(inp))
        out += shortcut
        return out


class Xception(nn.Module):
    """Xception backbone: two plain convolutions, twelve residual ``Block``s,
    two separable convolutions and a linear classification head.

    The head is created as ``self.fc``. Wrappers in this file rename it to
    ``last_linear`` (and delete ``fc``); ``logits`` accepts either layout so
    the class also works when used on its own.

    Args:
        num_classes: Output size of the linear head.
    """

    def __init__(self, num_classes=1000):
        super(Xception, self).__init__()
        self.num_classes = num_classes

        # Stem: 3 -> 32 (stride 2, no padding) -> 64 channels.
        self.conv1 = nn.Conv2d(3,32,3,2,0,bias=False)
        self.bn1 = nn.BatchNorm2d(32)
        self.relu = nn.ReLU(inplace=True)

        self.conv2 = nn.Conv2d(32,64,3,bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Strided blocks that widen the features to 728 channels.
        self.block1=Block(64,128,2,2,start_with_relu=False,grow_first=True)
        self.block2=Block(128,256,2,2,start_with_relu=True,grow_first=True)
        self.block3=Block(256,728,2,2,start_with_relu=True,grow_first=True)

        # Eight stride-1 blocks that keep 728 channels.
        self.block4=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block5=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block6=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block7=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        self.block8=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block9=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block10=Block(728,728,3,1,start_with_relu=True,grow_first=True)
        self.block11=Block(728,728,3,1,start_with_relu=True,grow_first=True)

        # Final strided block grows to 1024 channels in its last unit.
        self.block12=Block(728,1024,2,2,start_with_relu=True,grow_first=False)

        self.conv3 = SeparableConv2d(1024,1536,3,1,1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536,2048,3,1,1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.fc = nn.Linear(2048, num_classes)

    def features(self, input):
        """Run the convolutional trunk and return the 2048-channel feature map
        (after the last batch norm, before the final ReLU)."""
        x = self.conv1(input)
        x = self.bn1(x)
        x = self.relu(x)

        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)

        x = self.conv4(x)
        x = self.bn4(x)
        return x

    def logits(self, features):
        """ReLU, global average pool, flatten and apply the classification head.

        Note that ``self.relu`` is in-place, so ``features`` is modified.
        """
        x = self.relu(features)

        x = F.adaptive_avg_pool2d(x, (1, 1))
        x = x.view(x.size(0), -1)

        # The head lives in ``last_linear`` once a wrapper has renamed it and
        # in ``fc`` otherwise; looking it up here keeps a stand-alone
        # Xception from failing with AttributeError.
        head = getattr(self, 'last_linear', None)
        if head is None:
            head = self.fc
        return head(x)

    def forward(self, input):
        """Return class scores for a batch of 3-channel images."""
        x = self.features(input)
        x = self.logits(x)
        return x

## Same as the existing Xception, with only Dropout added
class xception(nn.Module):
    """Xception backbone whose classification head is rebuilt for
    ``num_out_classes`` outputs, optionally preceded by dropout.

    Args:
        num_out_classes: Number of output scores.
        dropout: Dropout probability placed in front of the linear head. A
            falsy value (``0``, ``None``, ``False``) yields a bare
            ``nn.Linear`` head instead of a ``Dropout -> Linear`` sequence,
            which also changes the head's state_dict key names.
    """

    def __init__(self, num_out_classes=2, dropout=0.5):
        super().__init__()

        backbone = Xception(num_classes=num_out_classes)
        self.model = backbone

        # Expose the head under the name Xception.logits() looks up.
        backbone.last_linear = backbone.fc
        del backbone.fc

        in_dim = backbone.last_linear.in_features
        if dropout:
            head = nn.Sequential(
                nn.Dropout(p=dropout),
                nn.Linear(in_dim, num_out_classes),
            )
        else:
            head = nn.Linear(in_dim, num_out_classes)
        backbone.last_linear = head

    def forward(self, x):
        """Delegate to the wrapped backbone and return its class scores."""
        return self.model(x)

In [None]:
# Per-split preprocessing. Every split converts to a tensor, resizes to
# 224x224 and normalizes each of the 3 channels with mean 0.5 / std 0.5;
# only 'train' additionally applies a random horizontal flip (placed between
# the resize and the normalization).
xception_default = {
    split: transforms.Compose(
        [transforms.ToTensor(), transforms.Resize((224, 224))]
        + ([transforms.RandomHorizontalFlip()] if split == 'train' else [])
        + [transforms.Normalize([0.5] * 3, [0.5] * 3)]
    )
    for split in ('train', 'valid', 'test')
}

In [None]:
# Validate the checkpoint path first so a typo fails immediately, before the
# network is allocated on the GPU. An explicit exception is used instead of
# `assert`, which is skipped when Python runs with -O.
if not os.path.isfile(args.save_fn):
    raise FileNotFoundError("checkpoint file not found: '{}'".format(args.save_fn))

print("=> creating model '{}'".format('xception'))
model = xception(num_out_classes=2, dropout=0.5).cuda()

# NOTE(security): torch.load unpickles the file -- only load checkpoints that
# come from a trusted source.
# map_location='cpu' makes loading independent of the GPU index the checkpoint
# was saved from; load_state_dict then copies the values into the model's
# (CUDA) parameters.
model.load_state_dict(torch.load(args.save_fn, map_location='cpu')['state_dict'])
print("=> model weight '{}' is loaded".format(args.save_fn))

# Inference mode: disables dropout and uses BatchNorm running statistics.
model = model.eval()

In [None]:
# Inference uses the 'test' pipeline: ToTensor -> Resize(224, 224) -> Normalize,
# with no random flip.
transform = xception_default['test']

In [None]:
def _count_frame_votes(video_path, model, transform, frames_per_batch=10):
    """Classify a video's frames in fixed-size batches and tally the votes.

    Frames are read sequentially in groups of ``frames_per_batch``; the number
    of groups is the reported frame count divided by ``frames_per_batch``, so
    trailing frames that do not fill a group are not classified. A group in
    which any read fails is skipped.

    Args:
        video_path: Path handed to ``cv2.VideoCapture``.
        model: Classifier returning one row of class scores per image.
        transform: Callable turning an RGB PIL image into a model input tensor.
        frames_per_batch: Number of consecutive frames classified together.

    Returns:
        ``(fake_count, real_count)``: number of frames whose arg-max class was
        0 and 1 respectively. Both are 0 when nothing could be classified.
    """
    fake_count, real_count = 0, 0
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            return fake_count, real_count

        # Feed batches to whatever device the model's weights live on.
        device = next(model.parameters()).device
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))

        for _group in range(total_frames // frames_per_batch):
            frames = []
            for _slot in range(frames_per_batch):
                ret, frame = cap.read()
                if ret:
                    frames.append(frame)

            if len(frames) != frames_per_batch:
                continue

            # Convert the frames into a form the Xception model accepts:
            # OpenCV BGR array -> RGB PIL image -> preprocessed tensor.
            batch = torch.stack([
                transform(Image.fromarray(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)))
                for frame in frames
            ]).to(device)

            # Xception model prediction
            with torch.no_grad():
                outputs = model(batch)
                _, predicted = torch.max(outputs, 1)

            # Increment the fake and real counts.
            # Assumption: class 0 is fake and class 1 is real.
            fake_count += (predicted == 0).sum().item()
            real_count += (predicted == 1).sum().item()
    finally:
        # Always free the decoder, even if preprocessing or inference raises.
        cap.release()

    return fake_count, real_count


predictions = []

# NOTE(review): `val_data` is not defined in any visible cell of this notebook;
# it must be a DataFrame with a 'path' column before this cell can run.
for _idx, row in tqdm(val_data.iterrows(), total=len(val_data), desc="Processing Videos"):
    video_path = row['path']
    video_file = os.path.basename(video_path)

    # Only .mp4 files are processed; other rows produce no prediction entry.
    if not video_file.endswith('.mp4'):
        continue

    fake_count, real_count = _count_frame_votes(video_path, model, transform)

    if fake_count + real_count == 0:
        # Unreadable or too-short video: make the fallback visible instead of
        # silently reporting it as 'real'.
        tqdm.write(f"Warning: no frames classified for {video_file}; defaulting to 'real'")

    # Compare the number of fake and real frames of each video (ties -> 'real').
    if fake_count > real_count:
        predictions.append({'video': video_file, 'prediction': 'fake'})
    else:
        predictions.append({'video': video_file, 'prediction': 'real'})

# Print the prediction results
for pred in predictions:
    print(f"Video: {pred['video']}, Prediction: {pred['prediction']}")