In [1]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split, Subset
from torchvision import datasets
from torchvision.models import resnet50
from SSP.networks.ssp import ssp
from torchvision.models.feature_extraction import create_feature_extractor
from tqdm import tqdm
import clip
from models import TransformerClassifier
import matplotlib.pyplot as plt
import pandas as pd
import os

# Set up

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device

'cuda'

In [6]:
_datasets = ['adm', 'biggan', 'glide', 'midjourney', 'sdv5', 'vqdm', 'wukong']

dataset = 'all'
os.makedirs(f'./weights/{dataset}', exist_ok=True)
os.makedirs(f'./results/{dataset}', exist_ok=True)

In [7]:
clip_model, preprocess = clip.load("ViT-B/32", device=device)

clip_model.float()

student_model = nn.TransformerEncoder(nn.TransformerEncoderLayer(d_model=512, nhead=4, dropout=0.2, batch_first=True), num_layers=1).to(device).eval()
student_model.load_state_dict(torch.load(f"./weights/{dataset}/student.pth", weights_only=True))

teacher_model = TransformerClassifier().to(device).eval()
teacher_model.load_state_dict(torch.load(f"./weights/{dataset}/teacher.pth", weights_only=True))
teacher_model = create_feature_extractor(teacher_model, return_nodes={"transformer_encoder": "layer1", "fc": "layer2"})

classifier = TransformerClassifier().to(device).eval()
classifier.load_state_dict(torch.load(f"./weights/{dataset}/classifier.pth", weights_only=True))

# resnet = resnet50(num_classes=2).to(device)
# resnet.load_state_dict(torch.load(f"./weights/{dataset}/resnet50.pth", weights_only=True))

# _ssp = ssp().to(device)
# _ssp.load_state_dict(torch.load(f"./weights/{dataset}/ssp.pth", weights_only=True))

<All keys matched successfully>

# Classifier

In [9]:
result = []

for _ in _datasets:
    test_path = f"./data/{_}/val"
    test_folder = datasets.ImageFolder(root=test_path, transform=preprocess)
    test_loader = DataLoader(test_folder, batch_size=32, shuffle=False)
    
    with tqdm(total=len(test_loader), desc=f'Test {_}', unit='batch') as pbar:
        with torch.inference_mode():
            correct = 0
            for batch, (images, labels) in enumerate(test_loader):
                images, labels = images.to(device), labels.to(device)
                image_features = clip_model.encode_image(images)
                student_output = student_model(image_features)
                teacher_output = teacher_model(image_features)['layer1']
                logits = classifier(torch.pow(student_output - teacher_output, 2))
                
                predictions = torch.argmax(logits, dim=-1)
                correct += (predictions == labels).sum().item()
                
                pbar.update()
                
        acc = correct/len(test_loader.dataset)*100
        pbar.set_postfix({'acc': acc})
                
    result.append({'dataset': _, 'acc': acc})

Test adm: 100%|██████████| 32/32 [00:09<00:00,  3.38batch/s, acc=54.9]
Test biggan: 100%|██████████| 32/32 [00:08<00:00,  3.96batch/s, acc=59.6]
Test glide: 100%|██████████| 32/32 [00:08<00:00,  3.90batch/s, acc=60.8]
Test midjourney: 100%|██████████| 32/32 [00:19<00:00,  1.66batch/s, acc=51.5]
Test sdv5: 100%|██████████| 32/32 [00:11<00:00,  2.68batch/s, acc=49.7]
Test vqdm: 100%|██████████| 32/32 [00:08<00:00,  3.60batch/s, acc=49.9]
Test wukong: 100%|██████████| 32/32 [00:11<00:00,  2.68batch/s, acc=49.9]


In [10]:
result = pd.DataFrame(result)
result.to_csv(f'./results/{dataset}/classifier_test.csv', index=False)

print(f'Avg acc: {result['acc'].mean()}')
result

Avg acc: 53.75714285714285


Unnamed: 0,dataset,acc
0,adm,54.9
1,biggan,59.6
2,glide,60.8
3,midjourney,51.5
4,sdv5,49.7
5,vqdm,49.9
6,wukong,49.9


# Teacher

In [11]:
result = []

for _ in _datasets:
    test_path = f"./data/{_}/val"
    test_folder = datasets.ImageFolder(root=test_path, transform=preprocess)
    test_loader = DataLoader(test_folder, batch_size=32, shuffle=False)
    
    with tqdm(total=len(test_loader), desc=f'Test {_}', unit='batch') as pbar:
        with torch.inference_mode():
            correct = 0
            for batch, (images, labels) in enumerate(test_loader):
                images, labels = images.to(device), labels.to(device)
                image_features = clip_model.encode_image(images)
                logits = teacher_model(image_features)['layer2']
                
                predictions = torch.argmax(logits, dim=-1)
                correct += (predictions == labels).sum().item()
                
                pbar.update()
                
        acc = correct/len(test_loader.dataset)*100
        pbar.set_postfix({'acc': acc})
                
    result.append({'dataset': _, 'acc': acc})

Test adm: 100%|██████████| 32/32 [00:08<00:00,  3.59batch/s, acc=98.8]
Test biggan: 100%|██████████| 32/32 [00:08<00:00,  3.94batch/s, acc=59.7]
Test glide: 100%|██████████| 32/32 [00:08<00:00,  3.92batch/s, acc=59.1]
Test midjourney: 100%|██████████| 32/32 [00:20<00:00,  1.57batch/s, acc=54.2]
Test sdv5: 100%|██████████| 32/32 [00:11<00:00,  2.74batch/s, acc=49.4]
Test vqdm: 100%|██████████| 32/32 [00:08<00:00,  3.66batch/s, acc=52]
Test wukong: 100%|██████████| 32/32 [00:11<00:00,  2.79batch/s, acc=50.1]


In [12]:
result = pd.DataFrame(result)
result.to_csv(f'./results/{dataset}/teacher_test.csv', index=False)

print(f'Avg acc: {result['acc'].mean()}')
result

Avg acc: 60.471428571428575


Unnamed: 0,dataset,acc
0,adm,98.8
1,biggan,59.7
2,glide,59.1
3,midjourney,54.2
4,sdv5,49.4
5,vqdm,52.0
6,wukong,50.1


# Resnet

In [None]:
# result = []

# for _ in _datasets:
#     test_path = f"./data/{_}/val"
#     test_folder = datasets.ImageFolder(root=test_path, transform=preprocess)
#     test_loader = DataLoader(test_folder, batch_size=32, shuffle=False)
    
#     with tqdm(total=len(test_loader), desc=f'Test {_}', unit='batch') as pbar:
#         with torch.inference_mode():
#             correct = 0
#             for batch, (images, labels) in enumerate(test_loader):
#                 images, labels = images.to(device), labels.to(device)
#                 logits = resnet(images)
                
#                 predictions = torch.argmax(logits, dim=-1)
#                 correct += (predictions == labels).sum().item()
                
#                 pbar.update()
                
#         acc = correct/len(test_loader.dataset)*100
#         pbar.set_postfix({'acc': acc})
                
#     result.append({'dataset': _, 'acc': acc})

Test adm: 100%|██████████| 32/32 [00:09<00:00,  3.48batch/s, acc=47.1]
Test biggan: 100%|██████████| 32/32 [00:07<00:00,  4.03batch/s, acc=50.7]
Test glide: 100%|██████████| 32/32 [00:08<00:00,  3.85batch/s, acc=47.8]
Test midjourney: 100%|██████████| 32/32 [00:19<00:00,  1.60batch/s, acc=51.2]
Test sdv5: 100%|██████████| 32/32 [00:10<00:00,  2.95batch/s, acc=49.6]
Test vqdm: 100%|██████████| 32/32 [00:08<00:00,  3.93batch/s, acc=47.8]
Test wukong: 100%|██████████| 32/32 [00:10<00:00,  2.97batch/s, acc=49.7]


In [None]:
# result = pd.DataFrame(result)
# result.to_csv(f'./results/{dataset}/resnet50_test.csv', index=False)

# print(f'Avg acc: {result['acc'].mean()}')
# result

Avg acc: 49.128571428571426


Unnamed: 0,dataset,acc
0,adm,47.1
1,biggan,50.7
2,glide,47.8
3,midjourney,51.2
4,sdv5,49.6
5,vqdm,47.8
6,wukong,49.7


# SSP

In [None]:
# result = []

# for _ in _datasets:
#     test_path = f"./data/{_}/val"
#     test_folder = datasets.ImageFolder(root=test_path, transform=preprocess)
#     test_loader = DataLoader(test_folder, batch_size=32, shuffle=False)
    
#     with tqdm(total=len(test_loader), desc=f'Test {_}', unit='batch') as pbar:
#         with torch.inference_mode():
#             correct = 0
#             for batch, (images, labels) in enumerate(test_loader):
#                 images, labels = images.to(device), labels.to(device)
#                 preds = _ssp(images).ravel()
                
#                 predictions = torch.sigmoid(preds) > 0.5
#                 correct += (predictions == labels).sum().item()
                
#                 pbar.update()
                
#         acc = correct/len(test_loader.dataset)*100
#         pbar.set_postfix({'acc': acc})
                
#     result.append({'dataset': _, 'acc': acc})

Test adm: 100%|██████████| 32/32 [00:10<00:00,  3.13batch/s, acc=51.2]
Test biggan: 100%|██████████| 32/32 [00:09<00:00,  3.41batch/s, acc=50.5]
Test glide: 100%|██████████| 32/32 [00:11<00:00,  2.74batch/s, acc=50]
Test midjourney: 100%|██████████| 32/32 [00:24<00:00,  1.33batch/s, acc=49.7]
Test sdv5: 100%|██████████| 32/32 [00:12<00:00,  2.47batch/s, acc=52.6]
Test vqdm: 100%|██████████| 32/32 [00:11<00:00,  2.80batch/s, acc=51.2]
Test wukong: 100%|██████████| 32/32 [00:15<00:00,  2.10batch/s, acc=52.7]


In [None]:
# result = pd.DataFrame(result)
# result.to_csv(f'./results/{dataset}/ssp_test.csv', index=False)

# print(f'Avg acc: {result['acc'].mean()}')
# result

Avg acc: 51.128571428571426


Unnamed: 0,dataset,acc
0,adm,51.2
1,biggan,50.5
2,glide,50.0
3,midjourney,49.7
4,sdv5,52.6
5,vqdm,51.2
6,wukong,52.7


# Random

In [8]:
from PIL import Image

def classifier_itest(image_path, device='cuda'):
    classes = [
        "AI",
        "Nature"
    ]

    input_image = Image.open(image_path)
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0).to(device)
    input_batch = clip_model.encode_image(input_batch)

    with torch.inference_mode():
        output = classifier(input_batch)
    probs = torch.nn.functional.softmax(output[0], dim=0)
    for idx, prob in enumerate(probs):
        print(f'{classes[idx]}: {prob*100:.2f}%')

    print(f'\nPrediction: {classes[torch.argmax(probs)]}')
    
def teacher_itest(image_path, device='cuda'):
    classes = [
        "AI",
        "Nature"
    ]

    input_image = Image.open(image_path)
    input_tensor = preprocess(input_image)
    input_batch = input_tensor.unsqueeze(0).to(device)
    input_batch = clip_model.encode_image(input_batch)

    with torch.inference_mode():
        output = teacher_model(input_batch)['layer2']
    probs = torch.nn.functional.softmax(output[0], dim=0)
    for idx, prob in enumerate(probs):
        print(f'{classes[idx]}: {prob*100:.2f}%')

    print(f'\nPrediction: {classes[torch.argmax(probs)]}')

In [25]:
img_path = "./data/sdv5/val/nature/ILSVRC2012_val_00002370.JPEG"

In [26]:
teacher_itest(img_path, device=device)

AI: 0.45%
Nature: 99.55%

Prediction: Nature


In [27]:
classifier_itest(img_path, device=device)

AI: 50.45%
Nature: 49.55%

Prediction: AI
