In [12]:
from sklearn.datasets import fetch_20newsgroups
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.feature_extraction import stop_words
from sklearn.model_selection import train_test_split
from string import punctuation
from collections import Counter
import matplotlib.pyplot as plt
import numpy as np
# Rebinds `stop_words` from the imported sklearn module to a plain list:
# the English stop words followed by each individual punctuation character.
stop_words = [*stop_words.ENGLISH_STOP_WORDS, *punctuation]

import torchvision.models as models
import torch

In [2]:
# Notebook-wide hyperparameters.
#   n_features: passed to TfidfVectorizer as `max_features` (vocabulary size cap).
hparams = dict(n_features=10000)

### 20 Newsgroups

In [17]:
# Load the predefined train and test subsets of the 20 Newsgroups corpus.
train = fetch_20newsgroups(subset='train')
test = fetch_20newsgroups(subset='test')

# Hold out 10% of the training documents as a validation split. Texts and
# targets are paired before splitting so they stay aligned, then unzipped
# back into parallel tuples. random_state=0 keeps the split reproducible.
train_data, val_data = train_test_split(
    list(zip(train.data, train.target)),
    test_size=0.1,
    random_state=0,
)
train_texts, train_targets = zip(*train_data)
val_texts, val_targets = zip(*val_data)

# The TF-IDF vocabulary is fitted on the training split only and then reused
# unchanged for the validation and test documents. Each sparse result is
# densified and cast to float32.
tfidf = TfidfVectorizer(stop_words=stop_words, max_features=hparams['n_features'])
train_features, val_features, test_features = (
    sparse_matrix.toarray().astype(np.float32)
    for sparse_matrix in (
        tfidf.fit_transform(train_texts),
        tfidf.transform(val_texts),
        tfidf.transform(test.data),
    )
)

# Per-class weight = (size of the most frequent class) / (size of this class),
# indexed by class id 0..n_classes-1.
class_counts = Counter(train_targets)
class_weights = np.array([
    max(class_counts.values()) / class_counts[class_id]
    for class_id in range(len(class_counts))
])

# Class id -> newsgroup name.
label_dict = dict(enumerate(train.target_names))

### CIFAR-100

In [28]:
import subprocess
import os

def unpickle(file):
    """Load and return the object stored in the pickle file at path ``file``.

    The file is read in binary mode with ``encoding='bytes'``, so strings
    pickled by Python 2 come back as ``bytes`` (callers in this notebook
    index the result with keys such as ``b'data'``).

    NOTE: unpickling can execute arbitrary code; only use on trusted files.
    """
    import pickle

    with open(file, 'rb') as handle:
        return pickle.load(handle, encoding='bytes')

class Cifar100:
    """CIFAR-100 wrapper that also pre-computes ResNet-152 features.

    Constructing an instance:
      1. downloads and extracts the CIFAR-100 python archive into the current
         working directory unless ``cifar-100-python`` already exists,
      2. parses the ``train`` and ``test`` pickles into images and labels,
      3. builds label-id -> name dictionaries from the ``meta`` pickle,
      4. runs every image through a pretrained torchvision ResNet-152 whose
         last child module (the classifier) has been removed.

    Attributes:
        train_images, test_images: lists of (32, 32, 3) uint8 arrays.
        train_flat_labels, test_flat_labels: lists of fine label ids.
        train_hier_labels, test_hier_labels: lists of (coarse, fine) id tuples.
        flat_label_dict: fine id -> "<coarse name>.<fine name>".
        hier_label_dict: [coarse id -> name, fine id -> name].
        resnet_vectorizer: the truncated network used as feature extractor.
        train_features, test_features: numpy arrays with one row per image.
    """

    _DATA_DIR = 'cifar-100-python'
    _ARCHIVE = 'cifar-100-python.tar.gz'
    _URL = 'https://www.cs.toronto.edu/~kriz/cifar-100-python.tar.gz'

    def __init__(self, val_split=None):
        """Load the dataset, build label dictionaries and extract features.

        Args:
            val_split: accepted for interface compatibility; it is not used
                anywhere in this class.

        Raises:
            subprocess.CalledProcessError: if ``wget`` or ``tar`` exits with a
                non-zero status while fetching/unpacking the archive.
        """
        self._ensure_dataset()

        self.train_images, self.train_flat_labels, self.train_hier_labels = self._parse_data(
            unpickle(self._DATA_DIR + '/train'))
        self.test_images, self.test_flat_labels, self.test_hier_labels = self._parse_data(
            unpickle(self._DATA_DIR + '/test'))

        # Distinct (coarse, fine) pairs seen in training, ordered by fine id.
        unique_hier_labels = sorted(set(self.train_hier_labels), key=lambda pair: pair[1])

        cifar_meta = unpickle(self._DATA_DIR + '/meta')
        coarse_names = [name.decode('utf8') for name in cifar_meta[b'coarse_label_names']]
        fine_names = [name.decode('utf8') for name in cifar_meta[b'fine_label_names']]

        self.flat_label_dict = {
            fine: coarse_names[coarse] + '.' + fine_names[fine]
            for coarse, fine in unique_hier_labels
        }
        self.hier_label_dict = [dict(enumerate(coarse_names)), dict(enumerate(fine_names))]

        # Feature extractor: pretrained ResNet-152 minus its last child module.
        resnet = models.resnet152(pretrained=True)
        self.resnet_vectorizer = torch.nn.Sequential(*(list(resnet.children())[:-1]))
        # Inference mode: BatchNorm must use its stored running statistics,
        # otherwise features depend on which images share a batch.
        self.resnet_vectorizer.eval()

        self.train_features = self.extract_resnet_features(self.train_images)
        self.test_features = self.extract_resnet_features(self.test_images)

    def _ensure_dataset(self):
        """Download and unpack the CIFAR-100 archive unless already present."""
        if os.path.exists(self._DATA_DIR):
            return
        if not os.path.exists(self._ARCHIVE):
            # check=True surfaces a failed download instead of silently
            # continuing to a confusing "file not found" later on.
            subprocess.run(['wget', self._URL], stdout=subprocess.DEVNULL, check=True)
        subprocess.run(['tar', '-zxf', self._ARCHIVE], stdout=subprocess.DEVNULL, check=True)

    def _parse_data(self, data_obj):
        """Convert a raw CIFAR pickle dict into images and labels.

        Args:
            data_obj: mapping with ``b'data'`` (one flat row of 3072 values per
                image: 1024 values per channel, channels stored one after the
                other), ``b'coarse_labels'`` and ``b'fine_labels'``.

        Returns:
            (images, flat_labels, hier_labels) where ``images`` is a list of
            (32, 32, 3) uint8 arrays, ``flat_labels`` is the list of fine label
            ids and ``hier_labels`` is the list of (coarse, fine) tuples.
        """
        pixels = np.asarray(data_obj[b'data'], dtype=np.uint8)
        # (N, 3072) -> (N, channel, row, col) -> (N, row, col, channel)
        channel_last = pixels.reshape(len(pixels), 3, 32, 32).transpose(0, 2, 3, 1)
        images = [np.ascontiguousarray(img) for img in channel_last]

        hier_labels = list(zip(data_obj[b'coarse_labels'], data_obj[b'fine_labels']))
        flat_labels = [fine_label for _, fine_label in hier_labels]

        return images, flat_labels, hier_labels

    def extract_resnet_features(self, images, batch_size=16):
        """Run ``images`` through ``self.resnet_vectorizer`` in mini-batches.

        Args:
            images: sequence of (H, W, 3) arrays (channel-last).
            batch_size: number of images per forward pass; must be >= 1.

        Returns:
            numpy array of the network outputs for all images, concatenated
            along axis 0 with the two trailing singleton axes squeezed away.

        Raises:
            ValueError: if ``batch_size`` is < 1, or (from ``np.concatenate``)
                if ``images`` is empty.

        NOTE(review): pixel values are fed as raw floats with no rescaling or
        mean/std normalisation; confirm that skipping the preprocessing
        torchvision documents for its pretrained models is intentional.
        """
        if batch_size < 1:
            raise ValueError('batch_size must be a positive integer')

        self.resnet_vectorizer.eval()
        features = []
        # no_grad: outputs must not carry autograd state, otherwise they cannot
        # be converted to numpy and every batch's graph would be kept in memory.
        with torch.no_grad():
            for start in range(0, len(images), batch_size):
                chunk = np.asarray(images[start:start + batch_size])
                batch = np.transpose(chunk, (0, 3, 1, 2))  # NHWC -> NCHW
                outputs = self.resnet_vectorizer(torch.tensor(batch, dtype=torch.float32))
                features.append(outputs.cpu().numpy())
        features = np.concatenate(features, axis=0).squeeze(-1).squeeze(-1)
        return features

In [None]:
# Instantiating runs everything in Cifar100.__init__: fetch/extract the archive
# if it is missing, parse the train/test pickles, build the label dictionaries
# and extract ResNet features for every train and test image.
cifar100 = Cifar100()

Downloading: "https://download.pytorch.org/models/resnet152-b121ed2d.pth" to /home/ec2-user/.cache/torch/checkpoints/resnet152-b121ed2d.pth
100%|██████████| 230M/230M [00:10<00:00, 22.3MB/s] 


### Logistic Regression

In [18]:
from sklearn.linear_model import LogisticRegression

In [19]:
# Only class_weight is set explicitly; every other hyperparameter keeps its
# default. The manually computed `class_weights` array is not passed here.
clf = LogisticRegression(class_weight='balanced')

In [20]:
# Fit on the dense TF-IDF features of the training split.
clf.fit(train_features, train_targets)



LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='warn', n_jobs=None, penalty='l2', random_state=None,
          solver='warn', tol=0.0001, verbose=0, warm_start=False)

In [21]:
# Score on the same split the model was fitted on.
clf.score(train_features, train_targets)

0.9643133355430339

In [22]:
# Score on the 10% validation split held out by train_test_split.
clf.score(val_features, val_targets)

0.8824569155987627

In [23]:
# Score on the predefined 'test' subset returned by fetch_20newsgroups.
clf.score(test_features, test.target)

0.8068242166755177