In [None]:
# --- Environment setup (Colab notebook cell) ---
import os
# An empty CUDA_VISIBLE_DEVICES hides every CUDA device from this process.
# It is set before torch is imported further down so the setting is already in
# place when the deep-learning libraries initialise.
os.environ["CUDA_VISIBLE_DEVICES"] = ""
print("CUDA_VISIBLE_DEVICES =", os.environ["CUDA_VISIBLE_DEVICES"])

# IPython shell magic (valid only inside a notebook, not in plain Python):
# installs the third-party packages imported below.
!pip install timm pillow seaborn python-docx --quiet

# Mount Google Drive; the data and output folders used later live under
# /content/drive.
from google.colab import drive
drive.mount('/content/drive')

import torch, timm, numpy as np
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report
from docx import Document
from docx.shared import Inches
import os, random

# --- Run configuration ---
# Inference is pinned to the CPU.
device = torch.device('cpu')
print("Running on device:", device)

# Number of images sent through the model per forward pass.
BATCH_SIZE = 16

# DATA_ROOT holds one sub-folder per class; OUTPUT_DIR receives every artifact
# written by this script and is created on demand.
DATA_ROOT = '/content/drive/MyDrive/dp/data/my/data'
OUTPUT_DIR = '/content/drive/MyDrive/dp/data/my/output'
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Artifact locations: confusion-matrix image and plain-text classification report.
CM_PNG, CLF_TXT = (
    os.path.join(OUTPUT_DIR, artifact_name)
    for artifact_name in (
        'swin_confusion_matrix.png',
        'swin_classification_report.txt',
    )
)

class Transform:
    """Turn a PIL image into a normalised, channel-first float tensor.

    Pipeline: convert to RGB -> resize -> scale to [0, 1] -> per-channel
    ``(x - mean) / std`` -> reorder from (H, W, C) to (C, H, W).

    Args:
        size: Target ``(width, height)`` handed to ``Image.resize``.
        mean: Per-channel mean subtracted after scaling to [0, 1]. The default
            matches the commonly used ImageNet statistics.
        std: Per-channel standard deviation used as divisor. The default
            matches the commonly used ImageNet statistics.
    """

    def __init__(self, size=(224, 224),
                 mean=(0.485, 0.456, 0.406),
                 std=(0.229, 0.224, 0.225)):
        self.size = tuple(size)
        # float32 keeps the whole pipeline (and the resulting tensor) in float32.
        self.mean = np.array(mean, dtype=np.float32)
        self.std = np.array(std, dtype=np.float32)

    def __call__(self, img: "Image.Image"):
        """Preprocess one image.

        Args:
            img: Any PIL image; it is converted to 3-channel RGB first.

        Returns:
            A float32 ``torch.Tensor`` of shape ``(3, height, width)``.
        """
        # Convert BEFORE resizing: Pillow always uses nearest-neighbour
        # resampling for palette ("P") and bilevel ("1") images, so resizing
        # first would interpolate palette indices instead of colours.
        rgb = img.convert('RGB').resize(self.size)
        arr = np.array(rgb, dtype=np.float32) / 255.0
        # mean/std have shape (3,) and broadcast over the trailing channel axis.
        arr = (arr - self.mean) / self.std
        return torch.from_numpy(arr).permute(2, 0, 1)

# Preprocessing instance that is passed to AnnotDataset when the dataset is built.
transform = Transform()

class AnnotDataset(Dataset):
    """Image-folder dataset in which every sub-directory of ``root`` is a class.

    Attributes:
        classes: Sorted list of class (sub-directory) names.
        cl2i: Mapping from class name to its integer label (index in ``classes``).
        samples: List of ``(image_path, label)`` tuples, ordered by class and
            then by file name.
        transform: Callable applied to each opened PIL image.
    """

    # File suffixes (compared case-insensitively) that count as images.
    IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform):
        """Index all images below ``root``.

        Args:
            root: Directory containing one sub-directory per class. Files that
                sit directly in ``root`` are ignored.
            transform: Callable taking a PIL image and returning the model input.
        """
        self.classes = sorted(
            d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
        )
        self.cl2i = {c: i for i, c in enumerate(self.classes)}
        self.samples = []
        for c in self.classes:
            class_dir = os.path.join(root, c)
            # os.listdir returns entries in arbitrary order; sorting makes the
            # sample order (and thus any per-sample output) reproducible.
            for fn in sorted(os.listdir(class_dir)):
                if fn.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append((os.path.join(class_dir, fn), self.cl2i[c]))
        self.transform = transform

    def __len__(self):
        """Return the number of indexed images."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load, transform and return ``(transformed_image, label)`` for ``idx``."""
        path, label = self.samples[idx]
        # The context manager releases the underlying file handle as soon as the
        # transform has consumed the pixel data.
        with Image.open(path) as img:
            return self.transform(img), label

# Index the evaluation images; the number of class folders fixes the output size.
dataset = AnnotDataset(DATA_ROOT, transform)
NUM_CLASSES = len(dataset.classes)
# shuffle=False keeps batches in dataset.samples order.
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=False)

# Ask timm for a NUM_CLASSES-way classifier instead of overwriting model.head
# with a bare nn.Linear. What `model.head` is (a plain Linear or a pooling
# classifier module) depends on the installed timm release, so replacing it by
# hand can leave the features unpooled and the logits with the wrong shape.
# NOTE(review): only the backbone receives pretrained weights here; the new
# classifier layer is randomly initialised. Load fine-tuned weights before
# relying on the metrics produced below.
model = timm.create_model(
    'swin_base_patch4_window7_224',
    pretrained=True,
    num_classes=NUM_CLASSES,
)
# Kept so existing references still work: width of the features fed to the classifier.
in_features = model.num_features
# eval() switches layers such as dropout to inference behaviour.
model.eval().to(device)

print(" Loaded pretrained Swin Transformer ")

# Single pass over the loader: all_y collects the ground-truth labels and all_p
# the arg-max class predictions, both in loader order.
all_y = []
all_p = []
with torch.no_grad():  # evaluation only, so no gradient bookkeeping
    for x, y in loader:
        x = x.to(device)
        logits = model(x)
        # Highest-scoring class along dim 1 for every item in the batch.
        preds = torch.argmax(logits, dim=1).cpu().numpy()
        all_y.extend(y.numpy())
        all_p.extend(preds)

# Pin the label set to every known class. Without it scikit-learn infers the
# labels from the data, so a class missing from both all_y and all_p would
# shrink the matrix and make target_names / tick labels mismatch.
label_ids = list(range(len(dataset.classes)))

# Confusion matrix: rows are true classes, columns are predicted classes.
cm = confusion_matrix(all_y, all_p, labels=label_ids)
plt.figure(figsize=(6,5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=dataset.classes, yticklabels=dataset.classes)
plt.xlabel('Predicted')
plt.ylabel('True')
# bbox_inches='tight' keeps long class names from being clipped in the PNG.
plt.savefig(CM_PNG, bbox_inches='tight'); plt.close()

# zero_division=0 reports 0 for classes without predictions/samples instead of
# emitting a warning per class.
rep = classification_report(all_y, all_p, labels=label_ids,
                            target_names=dataset.classes, zero_division=0)
with open(CLF_TXT, 'w', encoding='utf-8') as f:
    f.write(rep)

# Bundle the figure and the text report into a single Word document.
doc = Document()
doc.add_heading('Swin Transformer Inference Report', level=1)
doc.add_heading('Confusion Matrix', level=2)
doc.add_picture(CM_PNG, width=Inches(5))
doc.add_heading('Classification Report', level=2)
doc.add_paragraph(rep)
doc.save(os.path.join(OUTPUT_DIR, 'swin_inference_report.docx'))

print("✅ Inference complete — results saved in:", OUTPUT_DIR)
