In [43]:
from sage.all import *
from sage.rings.integer import Integer

# Stdlib imports come after the sage star-import so it cannot shadow them.
import functools
import itertools
import operator

In [44]:
class ClassList:
    """Multiset of classes ``0..n`` stored compressed as per-class counts.

    ``c_list[c]`` is the number of times class ``c`` occurs, so the
    "decompressed" form of ``ClassList([2, 3, 1])`` is the sorted list
    ``[0, 0, 1, 1, 1, 2]``.  ``n`` is ``len(c_list) - 1`` (the corresponding
    dimension of projective space).

    ``invalid_class_list`` is set on the (all-zero) result of ``__add__`` when a
    class outside ``0..n`` was added, and is carried through later additions.

    Instances hash by ``c_list`` and should be treated as immutable; every
    operation returns a new ``ClassList``.
    """

    # DO NOT USE THE FOLLOWING METHODS directly (original author's guidance).
    # They make a ClassList behave like its sorted, decompressed list of classes.

    def __init__(self, c_list):
        # Copy before measuring so any iterable (e.g. a generator) is accepted.
        self.c_list = list(c_list)
        self.n = len(self.c_list) - 1  # corresponding dimension of projective space
        self.invalid_class_list = False

    def empty_list(self):
        """Return a new all-zero ClassList with the same ``n``."""
        return ClassList([0] * (self.n + 1))

    def __add__(self, other):
        """Return the multiset union of ``self`` and ``other``.

        ``other`` may be a ClassList with the same ``n`` (counts are added
        position-wise) or a list/tuple of classes such as ``[1, 1, 2]`` (one
        occurrence added per entry).  If any entry of a list/tuple lies outside
        ``0..n``, an all-zero ClassList flagged ``invalid_class_list`` is
        returned instead.  An invalid operand makes the result invalid too.

        Raises:
            ValueError: ``other`` is a ClassList with a different ``n``.
            TypeError: ``other`` is of an unsupported type.
        """
        if isinstance(other, ClassList):
            if other.n != self.n:
                # zip() would otherwise silently truncate to the shorter list.
                raise ValueError(
                    "Cannot add ClassLists with n={} and n={}".format(self.n, other.n))
            result = ClassList([c1 + c2 for c1, c2 in zip(self.c_list, other.c_list)])
            result.invalid_class_list = self.invalid_class_list or other.invalid_class_list
            return result
        if isinstance(other, (list, tuple)):
            new_list = list(self.c_list)
            for c in other:
                if not 0 <= c <= self.n:
                    invalid = self.empty_list()
                    invalid.invalid_class_list = True
                    return invalid
                new_list[c] += 1
            result = ClassList(new_list)
            result.invalid_class_list = self.invalid_class_list
            return result
        raise TypeError("Invalid argument type: {}".format(type(other)))

    def __radd__(self, other):
        """Support ``[classes] + class_list``; addition is commutative here."""
        return self.__add__(other)

    def __getitem__(self, key):
        """Index or slice into the sorted, decompressed list of classes.

        An integer-like ``key`` (anything accepted by ``operator.index``, such
        as ``int`` or sage ``Integer``) returns the class at that position;
        negative keys count from the end.  A slice returns a new ClassList of
        the selected classes; any step is supported.

        Raises:
            IndexError: the integer key is out of range.
            TypeError: ``key`` is neither integer-like nor a slice.
        """
        if isinstance(key, slice):
            start, stop, step = key.indices(len(self))
            if step != 1:
                # Uncommon case: slice the materialised list and re-compress.
                return ClassList.compress(self.decompress()[key], self.n)
            # Class c occupies positions [lo, lo + count) of the decompressed
            # list; keep only the part of that run overlapping [start, stop).
            # max(0, ...) makes empty/reversed ranges (stop <= start) yield 0.
            new_c_list = []
            lo = 0
            for count in self.c_list:
                hi = lo + count
                new_c_list.append(max(0, min(hi, stop) - max(lo, start)))
                lo = hi
            return ClassList(new_c_list)

        try:
            index = operator.index(key)
        except TypeError:
            raise TypeError("Invalid argument type: {}".format(type(key))) from None
        length = len(self)
        if index < 0:
            index += length
        if not 0 <= index < length:
            raise IndexError("List index out of range")
        seen = 0
        for c, count in enumerate(self.c_list):
            seen += count
            if seen > index:
                return c
        # Only reachable if c_list holds negative counts.
        raise IndexError("List index out of range")

    def __setitem__(self, key, item):
        """Unsupported: ClassList is hashable and must not be mutated in place.

        Raises:
            TypeError: always.
        """
        # Explicit raise rather than ``assert False``, which python -O strips.
        raise TypeError("'ClassList' object does not support item assignment")

    def __delitem__(self, key):
        """Unsupported: use ``remove_class`` to get a new, smaller ClassList.

        Raises:
            TypeError: always.
        """
        raise TypeError("'ClassList' object does not support item deletion")

    def __iter__(self):
        """Iterate over the classes in sorted (decompressed) order."""
        return iter(self.decompress())

    def __hash__(self):
        return hash(tuple(self.c_list))

    def __eq__(self, other):
        """Compare with another ClassList (by counts) or a plain list of classes.

        A list such as ``[0, 0, 1]`` is compared as a multiset, so its order is
        irrelevant.  A list containing a class outside ``0..n``, or a
        non-integer, is simply unequal.  Other types get ``NotImplemented`` so
        Python applies its default comparison instead of raising.
        """
        if isinstance(other, ClassList):
            return other.c_list == self.c_list
        # Tuples are deliberately not accepted: they are hashable, and objects
        # that compare equal must also hash equally.
        if isinstance(other, list):
            try:
                return ClassList.compress(other, self.n).c_list == self.c_list
            except (IndexError, TypeError):
                return False
        return NotImplemented

    def __len__(self):
        """Total number of classes, counting multiplicity."""
        return sum(self.c_list)

    def __str__(self):
        return str(self.c_list)

    def __repr__(self):
        return "ClassList({!r})".format(self.c_list)

    # YOU MAY USE THE FOLLOWING METHODS

    def remove_class(self, c):
        """Return a new ClassList with one occurrence of class ``c`` removed.

        Raises:
            ValueError: ``c`` is outside ``0..n`` or has count 0.  Without this
                check the result would hold a negative count, or a negative
                ``c`` would decrement a class counted from the end.
        """
        if not 0 <= c <= self.n or self.c_list[c] <= 0:
            raise ValueError("class {} is not in the ClassList".format(c))
        new_c_list = list(self.c_list)
        new_c_list[c] -= 1
        return ClassList(new_c_list)

    def min(self):
        """Return the smallest class present, or -1 if the list is empty."""
        for c, count in enumerate(self.c_list):
            if count > 0:
                return c
        return -1

    @staticmethod
    def compress(l, n):
        """Build a ClassList for classes ``0..n`` from an iterable of classes.

        Raises:
            IndexError: an entry lies outside ``0..n``.  Negative entries are
                rejected explicitly instead of wrapping around.
        """
        c_list = [0] * (n + 1)
        for c in l:
            if not 0 <= c <= n:
                raise IndexError("class {} out of range 0..{}".format(c, n))
            c_list[c] += 1
        return ClassList(c_list)

    def decompress(self):
        """Return the sorted plain list of classes, e.g. ``[0, 0, 1, 1, 1, 2]``."""
        return ClassList.decompress_(self.c_list)

    @staticmethod
    def decompress_(c_list):
        """Expand a count vector into the sorted list of classes it encodes."""
        return [c for c, count in enumerate(c_list) for _ in range(count)]

    def bipart(self):
        """Return every split of ``self`` into a pair ``[A, self - A]``.

        One pair is produced for each sub-multiset ``A`` from ``sublist``, in
        that order.
        """
        return [
            [ClassList(t),
             ClassList([total - part for total, part in zip(self.c_list, t)])]
            for t in ClassList.sublist(self.c_list)
        ]

    @staticmethod
    def bipart_coeff(A1, A2):
        """Return ``prod_i binomial(A1.c_list[i], A2.c_list[i])``.

        ``product``/``binomial`` are sage's, so the result is a sage number.

        Raises:
            ValueError: ``A1`` and ``A2`` have different ``n``.
        """
        if A1.n != A2.n:
            # Explicit raise rather than ``assert``, which python -O strips.
            raise ValueError("bipart_coeff needs equal n, got {} and {}".format(A1.n, A2.n))
        return product([binomial(a, b) for a, b in zip(A1.c_list, A2.c_list)])

    @staticmethod
    @functools.lru_cache(maxsize=None)
    def sublist_helper(t):
        """Return every count vector ``s`` with ``0 <= s[i] <= t[i]``.

        ``t`` must be hashable (a tuple) because results are cached.  Vectors
        come in lexicographic order (first coordinate varies slowest).  An
        empty ``t`` returns ``[]``.  The returned list is the cached object
        itself, so callers must not mutate it; ``sublist`` returns a copy.
        """
        if not t:
            return []
        return [list(s) for s in itertools.product(*(range(k + 1) for k in t))]

    @staticmethod
    def sublist(t):
        """Return a fresh list of all sub-count-vectors of ``t``.

        See ``sublist_helper`` for the ordering.
        """
        # Copy so callers cannot corrupt the lru_cache'd result.
        return [list(s) for s in ClassList.sublist_helper(tuple(t))]
