In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.data.dataset import random_split
from torchtext.datasets import AG_NEWS
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import Vocab

from collections import Counter

from IPython.core.display import HTML, display
from models.generators import DiffusionModel
from models.discriminators import DiscriminatorNetCIFAR10
from torchvision import transforms
import torch.nn.functional as F
from captum.attr import visualization as viz
from captum.attr import Lime, LimeBase
from captum._utils.models.linear_model import SkLearnLinearRegression, SkLearnLasso
import os
import json

In [None]:
class model():
    """Caption -> generated image -> discriminator score pipeline used as the model to explain.

    ``encoder`` is a ``DiffusionModel`` whose ``forward`` output is preprocessed with
    ``self.transform`` and scored by ``decoder``, a ``DiscriminatorNetCIFAR10`` whose
    weights are restored from ``pretrained_path``.

    Args:
        pretrained_path: path to a checkpoint dict that holds the discriminator
            weights under the ``'model_state_dict'`` key.
    """

    def __init__(self, pretrained_path):
        self.encoder = DiffusionModel()
        self.decoder = DiscriminatorNetCIFAR10()
        # Load tensors onto the CPU regardless of the device the checkpoint was saved
        # from (same effect as ``map_location=lambda storage, loc: storage``).
        # NOTE(review): torch.load unpickles arbitrary objects - only load trusted files.
        checkpoint = torch.load(pretrained_path, map_location="cpu")
        self.decoder.load_state_dict(checkpoint['model_state_dict'])
        # The discriminator is only used for inference here: switch layers such as
        # dropout / batch-norm (if present) to eval behaviour so scores are stable.
        self.decoder.eval()
        # ToTensor -> scale to [-1, 1] per channel -> resize to the 32x32 input size
        # the CIFAR-10 discriminator is built for.
        self.transform = transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            transforms.Resize((32, 32)),
        ])

    def forward(self, x):
        """Generate an image from ``x`` and score it with the discriminator.

        Args:
            x: input handed unchanged to ``DiffusionModel.forward`` (the notebook
                passes a list holding one caption string).

        Returns:
            The discriminator output for a batch containing the single generated
            image, computed without autograd tracking.
        """
        image = self.encode_part(x)
        # NOTE(review): ToTensor presumably receives a PIL image / HxWxC array from the
        # encoder - TODO confirm. unsqueeze(0) adds the batch dimension of size 1.
        batch = self.transform(image).unsqueeze(0)
        # No gradients are needed to read off predictions; skipping autograd avoids
        # building a graph on every call (LIME calls this many times).
        with torch.no_grad():
            return self.decoder(batch)

    def encode_part(self, input):
        """Run only the generator stage and return ``DiffusionModel.forward(input)``."""
        return self.encoder.forward(input)

In [None]:
# Discriminator checkpoint; set XAIGAN_DISCRIMINATOR_PATH to point elsewhere instead of
# editing this machine-specific default.
pretrained_path = os.environ.get(
    "XAIGAN_DISCRIMINATOR_PATH",
    r"D:\Praktikum\xai-diffusion\xai-praktikum\xaigan\src\results\cifar-10\CIFAR10_only_SaliencyTrain\discriminator.pt",
)

# The pipeline class is named `model` and this cell rebinds `model` to an instance.
# Resolve the class first so re-running the cell does not fail with
# "'model' object is not callable".
model_cls = model if isinstance(model, type) else type(model)
model = model_cls(pretrained_path)

# Smoke test: score a single caption end to end.
text_inputs = ["A glass of beer is sitting next to a vase full of flowers."]
label = model.forward(text_inputs)
label

In [None]:
from lime import lime_text
from lime.lime_text import LimeTextExplainer
import numpy as np

class_names = ["false", "real"]
text = "A glass of beer is sitting next to a vase full of flowers."
# Every LIME sample runs the full diffusion + discriminator pipeline once.
# 5000 is LIME's default; lower it for quick experiments.
NUM_SAMPLES = 5000


def predict_proba(texts):
    """LIME ``classifier_fn``: score each perturbed text, return class probabilities.

    ``model.forward`` scores a single input list (it builds a batch of one image),
    so each text is passed on its own, mirroring ``model.forward([text])`` above.

    Args:
        texts: list of perturbed strings produced by LIME.

    Returns:
        ``np.ndarray`` of shape ``(len(texts), 2)`` with columns ordered like
        ``class_names``: ``[P(false), P(real)]``.
    """
    rows = []
    for perturbed in texts:
        score = model.forward([perturbed])
        # NOTE(review): assumes the discriminator emits one P(real) in [0, 1]
        # (e.g. a sigmoid output) - TODO confirm against DiscriminatorNetCIFAR10.
        p_real = float(torch.as_tensor(score).detach().reshape(-1)[0])
        rows.append([1.0 - p_real, p_real])
    return np.array(rows)


explainer = LimeTextExplainer(class_names=class_names)
exp = explainer.explain_instance(
    text, predict_proba, num_features=3, labels=[0, 1], num_samples=NUM_SAMPLES
)
# Only a cell's last expression is auto-displayed, so show the word weights explicitly.
display(exp.as_list(label=1))
exp.show_in_notebook(text=text, labels=(1,))