In [36]:
from torchtext import data, datasets
import torch
import torch.nn as nn
import torch.optim as optim
import random
import numpy as np
from tqdm.notebook import tqdm
from torch.utils.data import DataLoader

In [2]:
class IMDbDataset(torch.utils.data.Dataset):
    def __init__(self, split='train', max_vocab_size=20000, embedding_dim=100):

        RANDOM_SEED = 0

        self.TEXT = data.Field(tokenize='spacy', include_lengths=True)
        self.LABEL = data.LabelField(dtype=torch.float)

        train_data, test_data = datasets.IMDB.splits(self.TEXT, self.LABEL)
        # train_data, valid_data = train_data.split(random_state=random.seed(RANDOM_SEED), 
        #                                           split_ratio=0.8)

        self.TEXT.build_vocab(train_data, max_size=max_vocab_size, vectors=f"glove.6B.{embedding_dim}d")
        self.LABEL.build_vocab(train_data)

        if split == 'train':
            self.data = train_data
        elif split == 'valid':
            self.data = valid_data
        elif split == 'test':
            self.data = test_data
        else:
            raise ValueError("Invalid split. Use 'train', 'valid', or 'test'.")

        self.fields = {'text': self.TEXT, 'label': self.LABEL}

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        example = self.data[idx]
        text = example.text
        label = example.label
        
        return data.Example.fromlist([text, label], fields=[('text', self.TEXT), ('label', self.LABEL)])

In [3]:
class IMDbDataLoader:
    def __init__(self, dataset, batch_size=64):
        self.batch_size = batch_size
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Ensure `sort_key` is set based on the length of the text sequence
        sort_key = lambda x: len(x.text)

        # Create BucketIterator with the correct sort_key
        self.iterator = data.BucketIterator(
            dataset=dataset,
            batch_size=self.batch_size,
            sort_key=sort_key,  # Add sort_key here
            sort_within_batch=True,
            device=self.device
        )

    def __iter__(self):
        return iter(self.iterator)

    def __len__(self):
        return len(self.iterator)

In [4]:
imdb_train_dataset = IMDbDataset(split='train')
train_loader = IMDbDataLoader(imdb_train_dataset)
for x in train_loader:
    text, text_len = x.text
    print(text.shape)
    print(text_len.shape)
    print(x.label.shape)
    break

torch.Size([240, 64])
torch.Size([64])
torch.Size([64])


In [18]:
train_data = imdb_train_dataset
subset_fraction = 0.2
subset_size = int(len(train_data) * subset_fraction)
rand_idxs = np.random.choice(range(len(train_data)), subset_size)

In [3]:
from data_proc.dataset import TextDatasetSubset

In [4]:
ori_size = 20000
subset_size = int(ori_size * 0.2)
rand_idxs = np.random.choice(range(ori_size), subset_size)

In [5]:
subset = TextDatasetSubset(rand_idxs)

In [6]:
subset

<data_proc.dataset.TextDatasetSubset at 0x7faed9070340>

In [2]:
train_data = IMDbDataset(split='train')

NameError: name 'IMDbDataset' is not defined

In [19]:
TEXT = data.Field(tokenize='spacy', include_lengths=True)
LABEL = data.LabelField(dtype=torch.float)

idxs = rand_idxs
new_data = []
for i in idxs:
    example = train_data[i]
    new_example = data.Example.fromlist([example.text, example.label], [('text', TEXT), ('label', LABEL)])
    new_data.append(example)
train_data.examples = new_data


TEXT.build_vocab(train_data, max_size=20000, vectors=f"glove.6B.100d")
LABEL.build_vocab(train_data)

TypeError: 'Example' object is not iterable

In [7]:
subset = TextDatasetSubset(train_data, rand_idxs)

TypeError: 'Example' object is not iterable

In [12]:
new_data = []
for i in rand_idxs:
    new_data.append(train_data[i])

<__main__.IMDbDataset at 0x7ffaf621c430>

In [83]:
def random_deletion(tokens, p=0.5):
    if len(tokens) == 0:
        return tokens

    mask = np.random.rand(len(tokens)) > p
    remaining_tokens = list(np.array(tokens)[mask])
    # remaining_tokens = [token for token in tokens if random.uniform(0, 1) > p]
    if len(remaining_tokens) == 0:
        return [random.choice(tokens)]  # 如果全部删除，则随机保留一个
    return remaining_tokens

class AugmentedIMDbDataset(IMDbDataset):
    def __init__(self, split='train', max_vocab_size=20000, embedding_dim=100, augment_function = None, num_positive=2):
        super().__init__(split, max_vocab_size, embedding_dim)
        self.num_positive = num_positive
        self.augment_function = augment_function

        augmented_examples = []
        for example in self.data.examples:
            for _ in range(self.num_positive):
                example_augmented_text = self.augment_function(example.text)
                new_example = data.Example.fromlist(
                    [example_augmented_text, example.label], 
                    [('text', self.TEXT), ('label', self.LABEL)]
                )
                augmented_examples.append(new_example)
        self.data.examples = augmented_examples

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return super().__getitem__(idx)

In [84]:
augmented_imdb_train_dataset = AugmentedIMDbDataset(split='train',
                                                    augment_function=random_deletion,
                                                    num_positive=2)
augmented_imdb_train_loader = IMDbDataLoader(augmented_imdb_train_dataset, batch_size = 128)

In [86]:
# Iterate through the loader to check the augmented samples
for batch in augmented_imdb_train_loader:
    text, text_length = batch.text
    label = batch.label
    print(text.shape)
    print(text_length.shape)
    print(label.shape)
    break

torch.Size([516, 128])
torch.Size([128])
torch.Size([128])


## CLIP APPROX

In [27]:
from torchtext.vocab import GloVe

In [28]:
class IMDbGloveEmbeddedDataset(torch.utils.data.Dataset):
    def __init__(self, split='train', max_vocab_size=20000, embedding_dim=100):

        RANDOM_SEED = 0

        self.TEXT = data.Field(tokenize='spacy', include_lengths=True)
        self.LABEL = data.LabelField(dtype=torch.float)

        train_data, test_data = datasets.IMDB.splits(self.TEXT, self.LABEL)
        train_data, valid_data = train_data.split(random_state=random.seed(RANDOM_SEED), 
                                                  split_ratio=0.8)

        self.TEXT.build_vocab(train_data, max_size=max_vocab_size, vectors=f"glove.6B.{embedding_dim}d")
        self.LABEL.build_vocab(train_data)

        if split == 'train':
            self.data = train_data
        elif split == 'valid':
            self.data = valid_data
        elif split == 'test':
            self.data = test_data
        else:
            raise ValueError("Invalid split. Use 'train', 'valid', or 'test'.")

        glove = GloVe(name='6B', dim=embedding_dim)
        
        texts = []
        labels = []
        
        for sample in self.data:
            text = sample.text
            label = (1 if sample.label == 'pos' else 0)
            
            word_vectors = [glove[word.lower()] for word in text if word.lower() in glove.stoi]
            texts.append(torch.stack(word_vectors).mean(0))
            labels.append(label)
        self.X = torch.stack(texts)
        self.y = torch.Tensor(labels)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        return (self.X[idx], self.y[idx])

In [30]:
from torch.utils.data import DataLoader

In [31]:
GloveEmbeddedDataset = IMDbGloveEmbeddedDataset('train')

In [32]:
dl = DataLoader(GloveEmbeddedDataset, batch_size=64)

In [40]:
for x in dl:
    print(x[0].shape)
    print(x[1].shape)
    break

torch.Size([64, 100])
torch.Size([64])


In [37]:
class TwoLayerLinearModel(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(TwoLayerLinearModel, self).__init__()
        self.layer1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.layer2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.layer1(x)
        x = self.relu(x)
        x = self.layer2(x)
        return x

In [45]:
print(inputs.shape)

AttributeError: 'tuple' object has no attribute 'shape'

In [50]:
model = TwoLayerLinearModel(input_size=100, hidden_size=128, output_size=1)
criterion = nn.BCEWithLogitsLoss() 
optimizer = optim.SGD(model.parameters(), lr=0.1)

# Training loop
num_epochs = 10

for epoch in range(num_epochs):
    total_loss = 0.0
    correct_predictions = 0

    for inputs, targets in dl:
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs)

        # Compute the loss
        loss = criterion(outputs.squeeze(), targets)  # Squeeze to remove unnecessary dimensions

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

        # Compute the number of correct predictions
        predictions = torch.round(torch.sigmoid(outputs))  # Convert logits to binary predictions
        correct_predictions += (predictions == targets).sum().item()

    # Print average loss and training accuracy for the epoch
    average_loss = total_loss / len(GloveEmbeddedDataset)
    accuracy = correct_predictions / len(GloveEmbeddedDataset)
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss:.4f}, Training Accuracy: {accuracy:.4f}")

Epoch [1/10], Loss: 0.0107, Training Accuracy: 32.0109
Epoch [2/10], Loss: 0.0102, Training Accuracy: 32.0372
Epoch [3/10], Loss: 0.0097, Training Accuracy: 32.0855
Epoch [4/10], Loss: 0.0095, Training Accuracy: 32.1172
Epoch [5/10], Loss: 0.0093, Training Accuracy: 32.1241
Epoch [6/10], Loss: 0.0091, Training Accuracy: 32.1240
Epoch [7/10], Loss: 0.0089, Training Accuracy: 32.1322
Epoch [8/10], Loss: 0.0088, Training Accuracy: 32.1275
Epoch [9/10], Loss: 0.0087, Training Accuracy: 32.1417
Epoch [10/10], Loss: 0.0085, Training Accuracy: 32.1343


In [90]:
from typing import List

In [92]:
def encode_using_glove(dataset, device):
    glove = GloVe(name='6B', dim=100)
    texts = []
    
    for i in range(len(dataset)):
        sample = dataset[i]
        word_vectors = [glove[word.lower()] for word in sample.text if word.lower() in glove.stoi]
        texts.append(torch.stack(word_vectors).mean(0))
        
    Z = torch.stack(texts).to(device)
    return Z

In [67]:
def train_linear_classifier(
    X: torch.tensor, 
    y: torch.tensor, 
    representation_dim: int,
    num_classes: int,
    device: torch.device,
    reg_weight: float = 1e-3,
    n_lbfgs_steps: int = 500,
    verbose=False,
):
    if verbose:
        print('\nL2 Regularization weight: %g' % reg_weight)

    criterion = nn.CrossEntropyLoss()
    X_gpu = X.to(device)
    y_gpu = y.to(device)

    # Should be reset after each epoch for a completely independent evaluation
    clf = nn.Linear(representation_dim, num_classes).to(device)
    clf_optimizer = optim.LBFGS(clf.parameters())
    clf.train()

    for _ in tqdm(range(n_lbfgs_steps), desc="Training linear classifier using fraction of labels", disable=not verbose):
        def closure():
            clf_optimizer.zero_grad()
            raw_scores = clf(X_gpu)
            loss = criterion(raw_scores, y_gpu)
            loss += reg_weight * clf.weight.pow(2).sum()
            loss.backward()
            return loss
        clf_optimizer.step(closure)
    return clf

In [77]:
def partition_from_preds(preds):
    partition = {}
    for i, pred in enumerate(preds):
        if pred not in partition:
            partition[pred] = []
        partition[pred].append(i)
    return partition

In [93]:
def glove_approx(
    trainset: torch.utils.data.Dataset,
    labeled_example_indices: List[int], 
    labeled_examples_labels: np.array,
    num_classes: int,
    device: torch.device, 
    batch_size: int = 512,
    verbose: bool = False,
):
    Z = encode_using_glove(trainset, device)
    clf = train_linear_classifier(
        X=Z[labeled_example_indices], 
        y=torch.tensor(labeled_examples_labels), 
        representation_dim=len(Z[0]),
        num_classes=num_classes,
        device=device,
        verbose=False
    )
    preds = []
    for start_idx in range(0, len(Z), batch_size):
        preds.append(torch.argmax(clf(Z[start_idx:start_idx + batch_size]).detach(), dim=1).cpu())
    preds = torch.cat(preds).numpy()

    return partition_from_preds(preds)
    

In [94]:
rand_labeled_examples_indices = random.sample(range(len(imdb_train_dataset)), 500)
rand_labeled_examples_labels = [
    1 if imdb_train_dataset[i].label == 'pos' else 0 for i in rand_labeled_examples_indices
]

In [57]:
imdb_train_dataset = IMDbDataset(split='train')

In [96]:
partition = glove_approx(
    trainset=imdb_train_dataset,
    labeled_example_indices=rand_labeled_examples_indices, 
    labeled_examples_labels=rand_labeled_examples_labels,
    num_classes=2,
    device=device
)

In [26]:
import torch
import torch.nn as nn


class LSTM(nn.Module):
    def __init__(self, input_dim, embedding_dim=128, hidden_dim=256, pre_embedding=None, output_dim=None):

        super().__init__()

        self.embedding = nn.Embedding(input_dim, embedding_dim)

        if pre_embedding is not None:
            self.embedding.weight.data.copy_(pre_embedding)

        self.rnn = nn.LSTM(embedding_dim, hidden_dim, num_layers=3)

        self.representation_dim = hidden_dim

        if output_dim is not None:
            self.fc = nn.Linear(hidden_dim, output_dim)
        else:
            self.fc = None

    def forward(self, text, text_length):

        # [sentence len, batch size] => [sentence len, batch size, embedding size]
        embedded = self.embedding(text)

        packed = torch.nn.utils.rnn.pack_padded_sequence(embedded, text_length.cpu()).cuda()

        # [sentence len, batch size, embedding size] =>
        #  output: [sentence len, batch size, hidden size]
        #  hidden: [1, batch size, hidden size]
        packed_output, (hidden, cell) = self.rnn(packed)

        if self.fc is not None:
            return self.fc(hidden.squeeze(0)).view(-1)
        else:
            return hidden[-1]


In [19]:
trainset = imdb_train_dataset
net = LSTM(
            input_dim=len(trainset.TEXT.vocab),
            embedding_dim=100,
            pre_embedding=trainset.TEXT.vocab.vectors
        )

In [20]:
device = torch.device('cuda')
net = net.to(device)

In [21]:
text.shape

torch.Size([240, 64])

In [23]:
x = net(text, text_len)

In [25]:
x[-1].shape

torch.Size([64, 256])

## SAS

In [97]:
from abc import ABC
from typing import Dict, List, Optional
import math 
import pickle
import random 

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from sas.submodular_maximization import lazy_greedy
from tqdm import tqdm

In [123]:
from torch import nn 

class ProxyModel(nn.Module):
    def __init__(self, net, critic):
        super().__init__()
        self.net = net
        self.critic = critic
    def forward(self, text, text_lengths):
        return self.critic.project(self.net(text, text_lengths))

In [124]:
class BaseSubsetDataset(ABC, Dataset):
    def __init__(
        self,
        dataset: Dataset,
        subset_fraction: float,
        verbose: bool = False
    ):
        """
        :param dataset: Original Dataset
        :type dataset: Dataset
        :param subset_fraction: Fractional size of subset
        :type subset_fraction: float
        :param verbose: verbose
        :type verbose: boolean
        """
        self.dataset = dataset
        self.subset_fraction = subset_fraction
        self.len_dataset = len(self.dataset)
        self.subset_size = int(self.len_dataset * self.subset_fraction)
        self.subset_indices = None
        self.verbose = verbose 

    def initialization_complete(self):
        if self.verbose:
            print(f"Subset Size: {self.subset_size}")
            print(f"Discarded {self.len_dataset - self.subset_size} examples")

    def __len__(self):
        return self.subset_size
    
    def __getitem__(self, index):
        # Get the index for the corresponding item in the original dataset
        original_index = self.subset_indices[index]
        
        # Get the item from the original dataset at the corresponding index
        original_item = self.dataset[original_index]
        
        return original_item
    
    def save_to_file(self, filename):
        with open(filename, "wb") as f:
            pickle.dump(self.subset_indices, f)

In [125]:

from typing import Dict, Optional

import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import Dataset

from sas.submodular_maximization import lazy_greedy
from sas.subset_dataset import BaseSubsetDataset, SubsetSelectionObjective
from tqdm import tqdm

from data_proc.NLPDataLoader import IMDbDataLoader


class SASSubsetTextDataset(BaseSubsetDataset):
    def __init__(
            self,
            dataset: Dataset,
            subset_fraction: float,
            num_downstream_classes: int,
            device: torch.device,
            approx_latent_class_partition: Dict[int, int],
            proxy_model: Optional[nn.Module] = None,
            augmentation_distance: Optional[Dict[int, np.array]] = None,
            num_runs=1,
            pairwise_distance_block_size: int = 1024,
            threshold: float = 0.0,
            verbose: bool = False
    ):
        """
        dataset: Dataset
            Original dataset for contrastive learning. Assumes that dataset[i] returns a list of augmented views of the original example i.

        subset_fraction: float
            Fractional size of subset.

        num_downstream_classes: int
            Number of downstream classes (can be an estimate).

        proxy_model: nn.Module
            Proxy model to calculate the augmentation distance (and kmeans clustering if the avoid clip option is chosen).

        augmentation_distance: Dict[int, np.array]
            Pass a precomputed dictionary containing augmentation distance for each latent class.

        num_augmentations: int
            Number of augmentations to consider while approximating the augmentation distance.

        pairwise_distance_block_size: int
            Block size for calculating pairwise distance. This is just to optimize GPU usage while calculating pairwise distance and will not affect the subset created in any way.

        verbose: boolean
            Verbosity of the output.
        """
        super().__init__(
            dataset=dataset,
            subset_fraction=subset_fraction,
            verbose=verbose
        )
        self.device = device
        self.num_downstream_classes = num_downstream_classes
        self.proxy_model = proxy_model
        self.partition = approx_latent_class_partition
        self.augmentation_distance = augmentation_distance
        self.num_runs = num_runs
        self.pairwise_distance_block_size = pairwise_distance_block_size

        if self.augmentation_distance is None:
            self.augmentation_distance = self.approximate_augmentation_distance()

        class_wise_idx = {}
        for latent_class in tqdm(self.partition.keys(), desc="Subset Selection:", disable=not verbose):
            F = SubsetSelectionObjective(self.augmentation_distance[latent_class].copy(), threshold=threshold)
            class_wise_idx[latent_class] = lazy_greedy(F, range(len(self.augmentation_distance[latent_class])),
                                                       len(self.augmentation_distance[latent_class]))
            class_wise_idx[latent_class] = [self.partition[latent_class][i] for i in class_wise_idx[latent_class]]

        self.subset_indices = []
        for latent_class in class_wise_idx.keys():
            l = len(class_wise_idx[latent_class])
            self.subset_indices.extend(class_wise_idx[latent_class][:int(self.subset_fraction * l)])

        self.initialization_complete()

    def approximate_augmentation_distance(self):
        self.proxy_model = self.proxy_model.to(self.device)

        # Initialize augmentation distance with all 0s
        augmentation_distance = {}
        Z = self.encode_trainset()
        for latent_class in self.partition.keys():
            Z_partition = Z[self.partition[latent_class]]
            pairwise_distance = SASSubsetTextDataset.pairwise_distance(Z_partition, Z_partition)
            augmentation_distance[latent_class] = pairwise_distance.copy()
        return augmentation_distance

    def encode_trainset(self):
        trainloader = IMDbDataLoader(self.dataset, batch_size=self.pairwise_distance_block_size)

        with torch.no_grad():
            Z = []
            for input in trainloader:
                text, text_lengths = input.text
                Z.append(self.proxy_model(text.to(self.device), text_lengths.to(self.device)))
        return torch.cat(Z, dim=0)

    def encode_augmented_trainset(self, num_positives=1):
        trainloader = IMDbDataLoader(self.dataset, batch_size=self.pairwise_distance_block_size)

        with torch.no_grad():
            Z = []

            for input in trainloader:
                text, text_lengths = input.text
                Z.append(self.proxy_model(text.to(self.device), text_lengths.to(self.device)))
            Z = torch.stack(Z)

            aug_z = []
            idxs = torch.arange(0, len(Z), num_positives)
            for i in range(num_positives):
                aug_z.append(Z[idxs + i])

            Z = torch.cat(aug_z, dim=0)
        return Z

    @staticmethod
    def pairwise_distance(Z1: torch.tensor, Z2: torch.tensor, block_size: int = 1024):
        similarity_matrices = []
        for i in range(Z1.shape[0] // block_size + 1):
            similarity_matrices_i = []
            e = Z1[i * block_size:(i + 1) * block_size]
            for j in range(Z2.shape[0] // block_size + 1):
                e_t = Z2[j * block_size:(j + 1) * block_size].t()
                similarity_matrices_i.append(
                    np.array(
                        torch.cosine_similarity(e[:, :, None], e_t[None, :, :]).detach().cpu()
                    )
                )
            similarity_matrices.append(similarity_matrices_i)
        similarity_matrix = np.block(similarity_matrices)

        return similarity_matrix

In [126]:
import torch 
from sas.subset_dataset import SASSubsetDataset
net = torch.load("2023-12-0317:57:33.875617-imdb-LSTM-99-net.pt")
critic = torch.load("2023-12-0317:57:33.875617-imdb-LSTM-99-critic.pt")
proxy_model = ProxyModel(net, critic)
from data_proc import NLPDataLoader
     
subset_dataset = SASSubsetTextDataset(
    dataset=imdb_train_dataset,
    subset_fraction=0.2,
    num_downstream_classes=2,
    device=device,
    proxy_model=proxy_model,
    approx_latent_class_partition=partition,
    verbose=True
)

Subset Selection:: 100%|██████████████████████████| 2/2 [00:06<00:00,  3.23s/it]

Subset Size: 4000
Discarded 16000 examples





In [128]:
subset_dataset.subset_indices

[8472,
 14893,
 2441,
 9287,
 4128,
 9733,
 6734,
 3756,
 10750,
 5381,
 11370,
 14528,
 5011,
 8044,
 1009,
 6437,
 9895,
 15517,
 3662,
 16773,
 16960,
 19142,
 2366,
 6584,
 612,
 16492,
 2273,
 2153,
 352,
 5307,
 14809,
 5467,
 9310,
 3739,
 19251,
 8382,
 3999,
 14050,
 8120,
 19840,
 2623,
 13325,
 14108,
 13882,
 19774,
 12487,
 3031,
 8783,
 5702,
 18406,
 13661,
 14383,
 3590,
 297,
 8994,
 6536,
 10220,
 6939,
 9346,
 3534,
 18232,
 7234,
 5494,
 8773,
 13392,
 13361,
 9649,
 7227,
 7818,
 7297,
 8396,
 8812,
 13363,
 11936,
 2148,
 14019,
 7158,
 9240,
 16081,
 116,
 6411,
 10525,
 8466,
 2847,
 1672,
 1546,
 9799,
 711,
 11099,
 8751,
 13518,
 8404,
 686,
 19820,
 1357,
 18297,
 4207,
 7207,
 4424,
 15851,
 3086,
 6361,
 12913,
 16329,
 1291,
 443,
 6782,
 18708,
 2682,
 18313,
 3255,
 11058,
 17051,
 17041,
 12497,
 7049,
 1614,
 13943,
 15459,
 19240,
 10506,
 9604,
 5478,
 11688,
 15923,
 15820,
 12976,
 12152,
 16733,
 8320,
 12130,
 17530,
 6470,
 7232,
 18166,
 19813

## TEST on 2 Layer Linear

In [39]:
dataloader = DataLoader(imdb_train_dataset, batch_size=64,
                        shuffle=True, num_workers=2)

In [40]:
import torch
import torch.nn as nn
import torch.optim as optim

In [41]:
class SimpleClassifier(nn.Module):
    def __init__(self, embedding_dim, output_dim):
        super(SimpleClassifier, self).__init__()
        self.fc1 = nn.Linear(embedding_dim, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, embedded):
        x = torch.relu(self.fc1(embedded))
        output = self.fc2(x)
        return output

In [42]:
output_dim = 2  # 二分类问题，输出维度为2
embedding_dim = 100
model = SimpleClassifier(embedding_dim, output_dim)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [43]:
epochs = 5
for epoch in range(epochs):
    for batch in tqdm(dataloader):
        embedding, labels = batch
        labels = labels.to(torch.long)
        optimizer.zero_grad()
        predictions = model(embedding)
        loss = criterion(predictions, labels)
        loss.backward()
        optimizer.step()


  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

  0%|          | 0/313 [00:00<?, ?it/s]

In [44]:
imdb_test_dataset = IMDbDataset(split='test')

In [45]:
test_loader = DataLoader(imdb_test_dataset, batch_size=64,
                        shuffle=True, num_workers=2)

In [46]:
# 在测试集上评估模型
correct = 0
total = 0
with torch.no_grad():
    for batch in test_loader:
        embedding, labels = batch
        labels = labels.to(torch.long)
        predictions = model(embedding)
        _, predicted = torch.max(predictions.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

In [47]:
accuracy = correct / total
print(f'Test Accuracy: {accuracy * 100:.2f}%')

Test Accuracy: 77.79%


In [124]:
loader = IMDbDataLoader(imdb_train_dataset, batch_size=64)

In [158]:
sentences = []
for i in range(len(imdb_train_dataset)):
    text = imdb_train_dataset.__getitem__(i).text
    word_vectors = [glove[word.lower()] for word in text if word in glove.stoi]
    sentences.append(torch.stack(word_vectors).mean(0))

In [161]:
X.shape

torch.Size([20000, 100])