# Primer Avance de Código — Tesis de Detección y Segmentación de Deepfakes

In [None]:
!pip install timm torchvision matplotlib opencv-python

## Importar librerías

In [None]:

import torch
import torch.nn as nn
import torchvision.transforms as T
import timm
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
    

## Definir arquitectura híbrida: codificador Vision Transformer + decodificador convolucional tipo U-Net (sin conexiones de salto)

In [None]:

class ViT_UNet(nn.Module):
    """ViT encoder followed by a small convolutional decoder that predicts a mask.

    The encoder is timm's ``vit_base_patch16_224`` built with
    ``features_only=True``; only its last feature map is decoded. The decoder
    is a plain ``nn.Sequential`` (no skip connections) that upsamples by a
    total factor of 4 and ends in a sigmoid, so the output holds per-pixel
    probabilities in [0, 1].

    Args:
        pretrained: Forwarded to ``timm.create_model``. Defaults to ``True``
            (the original behaviour); pass ``False`` to skip loading
            pretrained weights.

    NOTE(review): ``features_only=True`` for ViT models needs a timm release
    that supports feature extraction on transformers — confirm the installed
    version.
    """

    def __init__(self, pretrained=True):
        super().__init__()
        self.vit = timm.create_model(
            'vit_base_patch16_224', pretrained=pretrained, features_only=True
        )
        # Channel count of the deepest feature map, which is the one decoded.
        vit_channels = self.vit.feature_info.channels()[-1]

        self.decoder = nn.Sequential(
            nn.Conv2d(vit_channels, 256, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Upsample(scale_factor=2),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size=1),
            nn.Sigmoid()
        )

    def forward(self, x):
        """Return a single-channel mask with the same height/width as ``x``.

        Args:
            x: Input batch; the last two dimensions are treated as the
                spatial size the mask must match.

        Returns:
            Tensor of sigmoid probabilities whose last two dimensions equal
            those of ``x``.
        """
        features = self.vit(x)[-1]
        mask = self.decoder(features)
        # The decoder only upsamples 4x, while a patch-16 encoder presumably
        # yields features at 1/16 of the input resolution, so the raw mask is
        # smaller than the input. Resize it so it can be compared element-wise
        # with an input-sized target (nn.BCELoss rejects mismatched shapes).
        # Bilinear weights are a convex combination, so values stay in [0, 1].
        if mask.shape[-2:] != x.shape[-2:]:
            mask = nn.functional.interpolate(
                mask, size=x.shape[-2:], mode='bilinear', align_corners=False
            )
        return mask
    

## Crear dataset sintético

In [None]:

class FakeFaceDataset(torch.utils.data.Dataset):
    """Synthetic segmentation dataset: one white circle on a gray background.

    Every access draws a fresh random sample (the index only selects nothing
    but is validated), so the same index returns different images over time.
    Randomness comes from the global ``np.random`` state.

    Args:
        n: Number of samples reported by ``len()``.
        img_size: Side length in pixels of the square image and mask.
    """

    # Resolution for which the position/size ranges (50-150 and 20-50 px)
    # were chosen; they are scaled proportionally for other image sizes.
    _REF_SIZE = 224

    def __init__(self, n=200, img_size=224):
        self.n = n
        self.img_size = img_size
        # ToTensor yields values in [0, 1]; Normalize(0.5, 0.5) then maps
        # them to [-1, 1].
        self.transform = T.Compose([
            T.ToTensor(),
            T.Normalize(0.5, 0.5)
        ])

    def __len__(self):
        """Return the nominal number of samples."""
        return self.n

    def _scaled(self, value):
        """Scale a pixel quantity defined at ``_REF_SIZE`` to ``img_size``."""
        return int(round(value * self.img_size / self._REF_SIZE))

    def __getitem__(self, idx):
        """Generate one random (image, mask) pair.

        Args:
            idx: Sample index in ``[-n, n)``.

        Returns:
            Tuple ``(img_t, mask_t)``: the transformed RGB image tensor and a
            float mask tensor of shape ``(1, img_size, img_size)`` that is 1
            inside the circle and 0 elsewhere.

        Raises:
            IndexError: If ``idx`` is out of range. Without this, the legacy
                sequence-iteration protocol (``for x in ds``) never stops.
        """
        if not -self.n <= idx < self.n:
            raise IndexError(
                f"index {idx} is out of range for dataset of size {self.n}"
            )

        # Same ranges as the original at img_size=224; max() keeps the
        # half-open randint interval non-empty for very small images.
        pos_lo, pos_hi = self._scaled(50), self._scaled(150)
        dia_lo, dia_hi = self._scaled(20), self._scaled(50)
        corner = np.random.randint(pos_lo, max(pos_hi, pos_lo + 1), size=2)
        # Cast NumPy scalars to plain ints before handing them to PIL.
        x0, y0 = int(corner[0]), int(corner[1])
        r = int(np.random.randint(dia_lo, max(dia_hi, dia_lo + 1)))
        # r is the side of the bounding box, i.e. the circle's diameter.
        box = (x0, y0, x0 + r, y0 + r)

        img = Image.new('RGB', (self.img_size, self.img_size), color='gray')
        ImageDraw.Draw(img).ellipse(box, fill='white')

        # 8-bit mask drawn with fill=1 so pixel values are exactly 0 or 1.
        mask = Image.new('L', (self.img_size, self.img_size), color=0)
        ImageDraw.Draw(mask).ellipse(box, fill=1)

        img_t = self.transform(img)
        mask_t = torch.from_numpy(np.array(mask)).unsqueeze(0).float()
        return img_t, mask_t

# Build a 100-sample synthetic training set and serve it in shuffled
# mini-batches of four.
train_ds = FakeFaceDataset(n=100)
train_loader = torch.utils.data.DataLoader(
    dataset=train_ds,
    batch_size=4,
    shuffle=True,
)
    

## Entrenar modelo por 3 épocas

In [None]:

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

model = ViT_UNet().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
# The model already ends in a sigmoid, so plain BCE on probabilities is used.
criterion = nn.BCELoss()

model.train()
for epoch in range(3):
    # Sum of per-batch losses; averaged over the number of batches below.
    total_loss = 0
    for batch in train_loader:
        imgs = batch[0].to(device)
        masks = batch[1].to(device)

        # Clear gradients left over from the previous step before computing
        # new ones.
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, masks)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    print(f"Epoch {epoch+1}/3 — Loss: {total_loss/len(train_loader):.4f}")
    

## Visualizar imágenes, máscaras reales y predicciones

In [None]:

# Switch the model to evaluation mode and pull a single batch from the
# training loader to inspect qualitatively.
model.eval()
imgs, masks = (t.to(device) for t in next(iter(train_loader)))

# Inference only: no autograd graph is needed for the predictions.
with torch.no_grad():
    preds = model(imgs)

def show(im, title='', mean=0.5, std=0.5):
    """Draw a tensor as an image on the current matplotlib axes.

    Tensors that squeeze to 2-D (masks, predictions) are shown in grayscale
    exactly as before. Tensors that remain 3-D after squeezing are treated as
    channel-first colour images: matplotlib requires channels last, so the
    axes are reordered, the normalisation is undone and values are clipped to
    [0, 1]. Previously such inputs reached ``imshow`` as (C, H, W) and were
    rejected as an invalid image shape.

    Args:
        im: Tensor to display; singleton dimensions are squeezed away.
        title: Text placed above the axes.
        mean: Mean used when the colour image was normalised (undone here).
        std: Standard deviation used when the colour image was normalised.
    """
    # detach() lets tensors that still track gradients be converted too.
    arr = im.detach().squeeze().cpu().numpy()
    if arr.ndim == 3:
        # (C, H, W) -> (H, W, C), then invert (x - mean) / std.
        arr = np.clip(arr.transpose(1, 2, 0) * std + mean, 0.0, 1.0)
        plt.imshow(arr)
    else:
        plt.imshow(arr, cmap='gray')
    plt.title(title)
    plt.axis('off')

# One row per sample, three columns: input image, ground-truth mask and
# predicted mask. At most four rows are drawn; the count is capped by the
# batch size so a smaller batch does not raise IndexError.
n_rows = min(4, imgs.shape[0])
columns = ((imgs, 'Imagen'), (masks, 'Máscara real'), (preds, 'Predicción'))

plt.figure(figsize=(8,8))
for i in range(n_rows):
    for col, (batch_tensor, title) in enumerate(columns, start=1):
        # Subplot indices are 1-based and run row by row.
        plt.subplot(n_rows, 3, 3*i + col)
        show(batch_tensor[i], title)
plt.tight_layout()
plt.show()
    